Shokoufehhh
commited on
Upload 40 files
Browse files- .gitignore +335 -0
- LICENSE +21 -0
- README.md +143 -0
- calc_metrics.py +67 -0
- diffusion_process.png +0 -0
- enhancement.py +84 -0
- logs/.keep +1 -0
- preprocessing/create_wsj0_chime3.py +130 -0
- preprocessing/create_wsj0_qut.py +172 -0
- preprocessing/create_wsj0_reverb.py +118 -0
- requirements.txt +26 -0
- requirements_version.txt +25 -0
- sgmse/backbones/__init__.py +6 -0
- sgmse/backbones/dcunet.py +627 -0
- sgmse/backbones/ncsnpp.py +419 -0
- sgmse/backbones/ncsnpp_48k.py +424 -0
- sgmse/backbones/ncsnpp_utils/layers.py +662 -0
- sgmse/backbones/ncsnpp_utils/layerspp.py +274 -0
- sgmse/backbones/ncsnpp_utils/normalization.py +215 -0
- sgmse/backbones/ncsnpp_utils/op/__init__.py +1 -0
- sgmse/backbones/ncsnpp_utils/op/fused_act.py +97 -0
- sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp +21 -0
- sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu +99 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp +23 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py +203 -0
- sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu +369 -0
- sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py +257 -0
- sgmse/backbones/ncsnpp_utils/utils.py +189 -0
- sgmse/backbones/shared.py +123 -0
- sgmse/data_module.py +236 -0
- sgmse/model.py +253 -0
- sgmse/sampling/__init__.py +143 -0
- sgmse/sampling/correctors.py +96 -0
- sgmse/sampling/predictors.py +76 -0
- sgmse/sdes.py +310 -0
- sgmse/util/inference.py +64 -0
- sgmse/util/other.py +141 -0
- sgmse/util/registry.py +34 -0
- sgmse/util/tensors.py +16 -0
- train.py +103 -0
.gitignore
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,macos,linux,emacs,windows,jupyternotebooks
|
3 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,macos,linux,emacs,windows,jupyternotebooks
|
4 |
+
|
5 |
+
### Emacs ###
|
6 |
+
# -*- mode: gitignore; -*-
|
7 |
+
*~
|
8 |
+
\#*\#
|
9 |
+
/.emacs.desktop
|
10 |
+
/.emacs.desktop.lock
|
11 |
+
*.elc
|
12 |
+
auto-save-list
|
13 |
+
tramp
|
14 |
+
.\#*
|
15 |
+
|
16 |
+
# Org-mode
|
17 |
+
.org-id-locations
|
18 |
+
*_archive
|
19 |
+
|
20 |
+
# flymake-mode
|
21 |
+
*_flymake.*
|
22 |
+
|
23 |
+
# figures
|
24 |
+
figures/
|
25 |
+
listen/
|
26 |
+
enhanced/
|
27 |
+
checkpoints/
|
28 |
+
baselines/
|
29 |
+
__paths__.py
|
30 |
+
*.csv
|
31 |
+
*.pdf
|
32 |
+
*.jpg
|
33 |
+
*.png
|
34 |
+
*.html
|
35 |
+
mushra/
|
36 |
+
|
37 |
+
# eshell files
|
38 |
+
/eshell/history
|
39 |
+
/eshell/lastdir
|
40 |
+
|
41 |
+
# elpa packages
|
42 |
+
/elpa/
|
43 |
+
|
44 |
+
# reftex files
|
45 |
+
*.rel
|
46 |
+
|
47 |
+
# AUCTeX auto folder
|
48 |
+
/auto/
|
49 |
+
|
50 |
+
# cask packages
|
51 |
+
.cask/
|
52 |
+
dist/
|
53 |
+
|
54 |
+
# Flycheck
|
55 |
+
flycheck_*.el
|
56 |
+
|
57 |
+
# server auth directory
|
58 |
+
/server/
|
59 |
+
|
60 |
+
# projectiles files
|
61 |
+
.projectile
|
62 |
+
|
63 |
+
# directory configuration
|
64 |
+
.dir-locals.el
|
65 |
+
|
66 |
+
# network security
|
67 |
+
/network-security.data
|
68 |
+
|
69 |
+
|
70 |
+
### JupyterNotebooks ###
|
71 |
+
# gitignore template for Jupyter Notebooks
|
72 |
+
# website: http://jupyter.org/
|
73 |
+
|
74 |
+
.ipynb_checkpoints
|
75 |
+
*/.ipynb_checkpoints/*
|
76 |
+
|
77 |
+
# IPython
|
78 |
+
profile_default/
|
79 |
+
ipython_config.py
|
80 |
+
|
81 |
+
# Remove previous ipynb_checkpoints
|
82 |
+
# git rm -r .ipynb_checkpoints/
|
83 |
+
|
84 |
+
### Linux ###
|
85 |
+
|
86 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
87 |
+
.fuse_hidden*
|
88 |
+
|
89 |
+
# KDE directory preferences
|
90 |
+
.directory
|
91 |
+
|
92 |
+
# Linux trash folder which might appear on any partition or disk
|
93 |
+
.Trash-*
|
94 |
+
|
95 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
96 |
+
.nfs*
|
97 |
+
|
98 |
+
### macOS ###
|
99 |
+
# General
|
100 |
+
.DS_Store
|
101 |
+
.AppleDouble
|
102 |
+
.LSOverride
|
103 |
+
|
104 |
+
# Icon must end with two \r
|
105 |
+
Icon
|
106 |
+
|
107 |
+
|
108 |
+
# Thumbnails
|
109 |
+
._*
|
110 |
+
|
111 |
+
# Files that might appear in the root of a volume
|
112 |
+
.DocumentRevisions-V100
|
113 |
+
.fseventsd
|
114 |
+
.Spotlight-V100
|
115 |
+
.TemporaryItems
|
116 |
+
.Trashes
|
117 |
+
.VolumeIcon.icns
|
118 |
+
.com.apple.timemachine.donotpresent
|
119 |
+
|
120 |
+
# Directories potentially created on remote AFP share
|
121 |
+
.AppleDB
|
122 |
+
.AppleDesktop
|
123 |
+
Network Trash Folder
|
124 |
+
Temporary Items
|
125 |
+
.apdisk
|
126 |
+
|
127 |
+
### Python ###
|
128 |
+
# Byte-compiled / optimized / DLL files
|
129 |
+
__pycache__/
|
130 |
+
*.py[cod]
|
131 |
+
*$py.class
|
132 |
+
|
133 |
+
# C extensions
|
134 |
+
*.so
|
135 |
+
|
136 |
+
# Distribution / packaging
|
137 |
+
.Python
|
138 |
+
build/
|
139 |
+
develop-eggs/
|
140 |
+
downloads/
|
141 |
+
eggs/
|
142 |
+
.eggs/
|
143 |
+
lib/
|
144 |
+
lib64/
|
145 |
+
parts/
|
146 |
+
sdist/
|
147 |
+
var/
|
148 |
+
wheels/
|
149 |
+
share/python-wheels/
|
150 |
+
*.egg-info/
|
151 |
+
.installed.cfg
|
152 |
+
*.egg
|
153 |
+
MANIFEST
|
154 |
+
|
155 |
+
# PyInstaller
|
156 |
+
# Usually these files are written by a python script from a template
|
157 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
158 |
+
*.manifest
|
159 |
+
*.spec
|
160 |
+
|
161 |
+
# Installer logs
|
162 |
+
pip-log.txt
|
163 |
+
pip-delete-this-directory.txt
|
164 |
+
|
165 |
+
# Unit test / coverage reports
|
166 |
+
htmlcov/
|
167 |
+
.tox/
|
168 |
+
.nox/
|
169 |
+
.coverage
|
170 |
+
.coverage.*
|
171 |
+
.cache
|
172 |
+
nosetests.xml
|
173 |
+
coverage.xml
|
174 |
+
*.cover
|
175 |
+
*.py,cover
|
176 |
+
.hypothesis/
|
177 |
+
.pytest_cache/
|
178 |
+
cover/
|
179 |
+
|
180 |
+
# Translations
|
181 |
+
*.mo
|
182 |
+
*.pot
|
183 |
+
|
184 |
+
# Django stuff:
|
185 |
+
*.log
|
186 |
+
local_settings.py
|
187 |
+
db.sqlite3
|
188 |
+
db.sqlite3-journal
|
189 |
+
|
190 |
+
# Flask stuff:
|
191 |
+
instance/
|
192 |
+
.webassets-cache
|
193 |
+
|
194 |
+
# Scrapy stuff:
|
195 |
+
.scrapy
|
196 |
+
|
197 |
+
# Sphinx documentation
|
198 |
+
docs/_build/
|
199 |
+
|
200 |
+
# PyBuilder
|
201 |
+
.pybuilder/
|
202 |
+
target/
|
203 |
+
|
204 |
+
# Jupyter Notebook
|
205 |
+
|
206 |
+
# IPython
|
207 |
+
|
208 |
+
# pyenv
|
209 |
+
# For a library or package, you might want to ignore these files since the code is
|
210 |
+
# intended to run in multiple environments; otherwise, check them in:
|
211 |
+
# .python-version
|
212 |
+
|
213 |
+
# pipenv
|
214 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
215 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
216 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
217 |
+
# install all needed dependencies.
|
218 |
+
#Pipfile.lock
|
219 |
+
|
220 |
+
# poetry
|
221 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
222 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
223 |
+
# commonly ignored for libraries.
|
224 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
225 |
+
#poetry.lock
|
226 |
+
|
227 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
228 |
+
__pypackages__/
|
229 |
+
|
230 |
+
# Celery stuff
|
231 |
+
celerybeat-schedule
|
232 |
+
celerybeat.pid
|
233 |
+
|
234 |
+
# SageMath parsed files
|
235 |
+
*.sage.py
|
236 |
+
|
237 |
+
# Environments
|
238 |
+
.env
|
239 |
+
.venv
|
240 |
+
env/
|
241 |
+
venv/
|
242 |
+
ENV/
|
243 |
+
env.bak/
|
244 |
+
venv.bak/
|
245 |
+
|
246 |
+
# Spyder project settings
|
247 |
+
.spyderproject
|
248 |
+
.spyproject
|
249 |
+
|
250 |
+
# Rope project settings
|
251 |
+
.ropeproject
|
252 |
+
|
253 |
+
# mkdocs documentation
|
254 |
+
/site
|
255 |
+
|
256 |
+
# mypy
|
257 |
+
.mypy_cache/
|
258 |
+
.dmypy.json
|
259 |
+
dmypy.json
|
260 |
+
|
261 |
+
# Pyre type checker
|
262 |
+
.pyre/
|
263 |
+
|
264 |
+
# pytype static type analyzer
|
265 |
+
.pytype/
|
266 |
+
|
267 |
+
# Cython debug symbols
|
268 |
+
cython_debug/
|
269 |
+
|
270 |
+
# PyCharm
|
271 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
272 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
273 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
274 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
275 |
+
#.idea/
|
276 |
+
|
277 |
+
### VisualStudioCode ###
|
278 |
+
.vscode/
|
279 |
+
.vscode/*
|
280 |
+
!.vscode/settings.json
|
281 |
+
!.vscode/tasks.json
|
282 |
+
!.vscode/launch.json
|
283 |
+
!.vscode/extensions.json
|
284 |
+
!.vscode/*.code-snippets
|
285 |
+
|
286 |
+
# Local History for Visual Studio Code
|
287 |
+
.history/
|
288 |
+
|
289 |
+
# Built Visual Studio Code Extensions
|
290 |
+
*.vsix
|
291 |
+
|
292 |
+
### VisualStudioCode Patch ###
|
293 |
+
# Ignore all local history of files
|
294 |
+
.history
|
295 |
+
.ionide
|
296 |
+
|
297 |
+
# Support for Project snippet scope
|
298 |
+
|
299 |
+
### Windows ###
|
300 |
+
# Windows thumbnail cache files
|
301 |
+
Thumbs.db
|
302 |
+
Thumbs.db:encryptable
|
303 |
+
ehthumbs.db
|
304 |
+
ehthumbs_vista.db
|
305 |
+
|
306 |
+
# Dump file
|
307 |
+
*.stackdump
|
308 |
+
|
309 |
+
# Folder config file
|
310 |
+
[Dd]esktop.ini
|
311 |
+
|
312 |
+
# Recycle Bin used on file shares
|
313 |
+
$RECYCLE.BIN/
|
314 |
+
|
315 |
+
# Windows Installer files
|
316 |
+
*.cab
|
317 |
+
*.msi
|
318 |
+
*.msix
|
319 |
+
*.msm
|
320 |
+
*.msp
|
321 |
+
|
322 |
+
# Windows shortcuts
|
323 |
+
*.lnk
|
324 |
+
|
325 |
+
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,macos,linux,emacs,windows,jupyternotebooks
|
326 |
+
|
327 |
+
|
328 |
+
|
329 |
+
# Custom ignores:
|
330 |
+
|
331 |
+
## Logs from W&B and PyTorch Lightning
|
332 |
+
/wandb
|
333 |
+
/lightning_logs
|
334 |
+
/logs
|
335 |
+
/jobs
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Signal Processing (SP), Universität Hamburg
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: en
|
3 |
+
tags:
|
4 |
+
- speech-enhancement
|
5 |
+
- dereverberation
|
6 |
+
- diffusion-models
|
7 |
+
- generative-models
|
8 |
+
- pytorch
|
9 |
+
- audio-processing
|
10 |
+
license: mit
|
11 |
+
datasets:
|
12 |
+
- VoiceBank-DEMAND
|
13 |
+
- WSJ0-CHiME3
|
14 |
+
- WSJ0-REVERB
|
15 |
+
- EARS-WHAM
|
16 |
+
- EARS-Reverb
|
17 |
+
model_name: speech-enhancement-dereverberation-diffusion
|
18 |
+
model_type: diffusion-based-generative-model
|
19 |
+
library_name: pytorch
|
20 |
+
inference: true
|
21 |
+
---
|
22 |
+
|
23 |
+
|
24 |
+
# Speech Enhancement and Dereverberation with Diffusion-based Generative Models
|
25 |
+
|
26 |
+
<img src="https://raw.githubusercontent.com/sp-uhh/sgmse/main/diffusion_process.png" width="500" alt="Diffusion process on a spectrogram: In the forward process noise is gradually added to the clean speech spectrogram x0, while the reverse process learns to generate clean speech in an iterative fashion starting from the corrupted signal xT.">
|
27 |
+
|
28 |
+
This repository contains the official PyTorch implementations for the papers:
|
29 |
+
|
30 |
+
- Simon Welker, Julius Richter, Timo Gerkmann, [*"Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain"*](https://www.isca-speech.org/archive/interspeech_2022/welker22_interspeech.html), ISCA Interspeech, Incheon, Korea, Sept. 2022. [[bibtex]](#citations--references)
|
31 |
+
- Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann, [*"Speech Enhancement and Dereverberation with Diffusion-Based Generative Models"*](https://ieeexplore.ieee.org/abstract/document/10149431), IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023. [[bibtex]](#citations--references)
|
32 |
+
- Julius Richter, Yi-Chiao Wu, Steven Krenn, Simon Welker, Bunlong Lay, Shinji Watanabe, Alexander Richard, Timo Gerkmann, [*"EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation"*](https://arxiv.org/abs/2406.06185), ISCA Interspecch, Kos, Greece, Sept. 2024. [[bibtex]](#citations--references)
|
33 |
+
|
34 |
+
Audio examples and supplementary materials are available on our [SGMSE project page](https://www.inf.uni-hamburg.de/en/inst/ab/sp/publications/sgmse) and [EARS project page](https://sp-uhh.github.io/ears_dataset/).
|
35 |
+
|
36 |
+
## Follow-up work
|
37 |
+
|
38 |
+
Please also check out our follow-up work with code available:
|
39 |
+
|
40 |
+
- Jean-Marie Lemercier, Julius Richter, Simon Welker, Timo Gerkmann, [*"StoRM: A Diffusion-based Stochastic Regeneration Model for Speech Enhancement and Dereverberation"*](https://ieeexplore.ieee.org/document/10180108), IEEE/ACM Transactions on Audio, Speech, Language Processing, vol. 31, pp. 2724 -2737, 2023. [[github]](https://github.com/sp-uhh/storm)
|
41 |
+
- Bunlong Lay, Simon Welker, Julius Richter, Timo Gerkmann, [*"Reducing the Prior Mismatch of Stochastic Differential Equations for Diffusion-based Speech Enhancement"*](https://www.isca-archive.org/interspeech_2023/lay23_interspeech.html), ISCA Interspeech, Dublin, Ireland, Aug. 2023. [[github]](https://github.com/sp-uhh/sgmse-bbed)
|
42 |
+
|
43 |
+
## Installation
|
44 |
+
|
45 |
+
- Create a new virtual environment with Python 3.11 (we have not tested other Python versions, but they may work).
|
46 |
+
- Install the package dependencies via `pip install -r requirements.txt`.
|
47 |
+
- Let pip resolve the dependencies for you. If you encounter any issues, please check `requirements_version.txt` for the exact versions we used.
|
48 |
+
- If using W&B logging (default):
|
49 |
+
- Set up a [wandb.ai](https://wandb.ai/) account
|
50 |
+
- Log in via `wandb login` before running our code.
|
51 |
+
- If not using W&B logging:
|
52 |
+
- Pass the option `--nolog` to `train.py`.
|
53 |
+
- Your logs will be stored as local CSVLogger logs in `lightning_logs/`.
|
54 |
+
|
55 |
+
## Pretrained checkpoints
|
56 |
+
|
57 |
+
- For the speech enhancement task, we offer pretrained checkpoints for models that have been trained on the VoiceBank-DEMAND and WSJ0-CHiME3 datasets, as described in our journal paper [2]. You can download them [here](https://drive.google.com/drive/folders/1CSnkhUSoiv3RG0xg7WEcVapyLuwDaLbe?usp=sharing).
|
58 |
+
- SGMSE+ trained on VoiceBank-DEMAND: `gdown 1_H3EXvhcYBhOZ9QNUcD5VZHc6ktrRbwQ`
|
59 |
+
- SGMSE+ trained on WSJ0-CHiME3: `gdown 16K4DUdpmLhDNC7pJhBBc08pkSIn_yMPi`
|
60 |
+
- For the dereverberation task, we offer a checkpoint trained on our WSJ0-REVERB dataset. You can download it [here](https://drive.google.com/drive/folders/1082_PSEgrqoVVrNsAkSIcpLF1AAtzGwV?usp=sharing).
|
61 |
+
- SGMSE+ trained on WSJ0-REVERB: `gdown 1eiOy0VjHh9V9ZUFTxu1Pq2w19izl9ejD`
|
62 |
+
- Note that this checkpoint works better with sampler settings `--N 50 --snr 0.33`.
|
63 |
+
- For 48 kHz models [3], we offer pretrained checkpoints for speech enhancement, trained on the EARS-WHAM dataset, and for dereverberation, trained on the EARS-Reverb dataset. You can download them [here](https://drive.google.com/drive/folders/1Tn6pVwjxUAy1DJ8167JCg3enuSi0hiw5?usp=sharing).
|
64 |
+
- SGMSE+ trained on EARS-WHAM: `gdown 1t_DLLk8iPH6nj8M5wGeOP3jFPaz3i7K5`
|
65 |
+
- SGMSE+ trained on EARS-Reverb: `gdown 1PunXuLbuyGkknQCn_y-RCV2dTZBhyE3V`
|
66 |
+
|
67 |
+
Usage:
|
68 |
+
- For resuming training, you can use the `--ckpt` option of `train.py`.
|
69 |
+
- For evaluating these checkpoints, use the `--ckpt` option of `enhancement.py` (see section **Evaluation** below).
|
70 |
+
|
71 |
+
## Training
|
72 |
+
|
73 |
+
Training is done by executing `train.py`. A minimal running example with default settings (as in our paper [2]) can be run with
|
74 |
+
|
75 |
+
```bash
|
76 |
+
python train.py --base_dir <your_base_dir>
|
77 |
+
```
|
78 |
+
|
79 |
+
where `your_base_dir` should be a path to a folder containing subdirectories `train/` and `valid/` (optionally `test/` as well). Each subdirectory must itself have two subdirectories `clean/` and `noisy/`, with the same filenames present in both. We currently only support training with `.wav` files.
|
80 |
+
|
81 |
+
To see all available training options, run `python train.py --help`. Note that the available options for the SDE and the backbone network change depending on which SDE and backbone you use. These can be set through the `--sde` and `--backbone` options.
|
82 |
+
|
83 |
+
**Note:**
|
84 |
+
- Our journal preprint [2] uses `--backbone ncsnpp`.
|
85 |
+
- For the 48 kHz model [3], use `--backbone ncsnpp_48k --n_fft 1534 --hop_length 384 --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
|
86 |
+
- Our Interspeech paper [1] uses `--backbone dcunet`. You need to pass `--n_fft 512` to make it work.
|
87 |
+
- Also note that the default parameters for the spectrogram transformation in this repository are slightly different from the ones listed in the first (Interspeech) paper (`--spec_factor 0.15` rather than `--spec_factor 0.333`), but we've found the value in this repository to generally perform better for both models [1] and [2].
|
88 |
+
|
89 |
+
## Evaluation
|
90 |
+
|
91 |
+
To evaluate on a test set, run
|
92 |
+
```bash
|
93 |
+
python enhancement.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir> --ckpt <path_to_model_checkpoint>
|
94 |
+
```
|
95 |
+
|
96 |
+
to generate the enhanced .wav files, and subsequently run
|
97 |
+
|
98 |
+
```bash
|
99 |
+
python calc_metrics.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir>
|
100 |
+
```
|
101 |
+
|
102 |
+
to calculate and output the instrumental metrics.
|
103 |
+
|
104 |
+
Both scripts should receive the same `--test_dir` and `--enhanced_dir` parameters. The `--cpkt` parameter of `enhancement.py` should be the path to a trained model checkpoint, as stored by the logger in `logs/`.
|
105 |
+
|
106 |
+
## Citations / References
|
107 |
+
|
108 |
+
We kindly ask you to cite our papers in your publication when using any of our research or code:
|
109 |
+
```bib
|
110 |
+
@inproceedings{welker22speech,
|
111 |
+
author={Simon Welker and Julius Richter and Timo Gerkmann},
|
112 |
+
title={Speech Enhancement with Score-Based Generative Models in the Complex {STFT} Domain},
|
113 |
+
year={2022},
|
114 |
+
booktitle={Proc. Interspeech 2022},
|
115 |
+
pages={2928--2932},
|
116 |
+
doi={10.21437/Interspeech.2022-10653}
|
117 |
+
}
|
118 |
+
```
|
119 |
+
```bib
|
120 |
+
@article{richter2023speech,
|
121 |
+
title={Speech Enhancement and Dereverberation with Diffusion-based Generative Models},
|
122 |
+
author={Richter, Julius and Welker, Simon and Lemercier, Jean-Marie and Lay, Bunlong and Gerkmann, Timo},
|
123 |
+
journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
|
124 |
+
volume={31},
|
125 |
+
pages={2351-2364},
|
126 |
+
year={2023},
|
127 |
+
doi={10.1109/TASLP.2023.3285241}
|
128 |
+
}
|
129 |
+
```
|
130 |
+
```bib
|
131 |
+
@inproceedings{richter2024ears,
|
132 |
+
title={{EARS}: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation},
|
133 |
+
author={Richter, Julius and Wu, Yi-Chiao and Krenn, Steven and Welker, Simon and Lay, Bunlong and Watanabe, Shinjii and Richard, Alexander and Gerkmann, Timo},
|
134 |
+
booktitle={ISCA Interspeech},
|
135 |
+
year={2024}
|
136 |
+
}
|
137 |
+
```
|
138 |
+
|
139 |
+
>[1] Simon Welker, Julius Richter, Timo Gerkmann. "Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain", ISCA Interspeech, Incheon, Korea, Sep. 2022.
|
140 |
+
>
|
141 |
+
>[2] Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann. "Speech Enhancement and Dereverberation with Diffusion-Based Generative Models", IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.
|
142 |
+
>
|
143 |
+
>[3] Julius Richter, Yi-Chiao Wu, Steven Krenn, Simon Welker, Bunlong Lay, Shinji Watanabe, Alexander Richard, Timo Gerkmann. "EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation", ISCA Interspeech, Kos, Greece, 2024.
|
calc_metrics.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import join
|
2 |
+
from glob import glob
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
from soundfile import read
|
5 |
+
from tqdm import tqdm
|
6 |
+
from pesq import pesq
|
7 |
+
import pandas as pd
|
8 |
+
import librosa
|
9 |
+
|
10 |
+
from pystoi import stoi
|
11 |
+
|
12 |
+
from sgmse.util.other import energy_ratios, mean_std
|
13 |
+
|
14 |
+
|
15 |
+
if __name__ == '__main__':
|
16 |
+
parser = ArgumentParser()
|
17 |
+
parser.add_argument("--clean_dir", type=str, required=True, help='Directory containing the clean data')
|
18 |
+
parser.add_argument("--noisy_dir", type=str, required=True, help='Directory containing the noisy data')
|
19 |
+
parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
data = {"filename": [], "pesq": [], "estoi": [], "si_sdr": [], "si_sir": [], "si_sar": []}
|
23 |
+
|
24 |
+
# Evaluate standard metrics
|
25 |
+
noisy_files = []
|
26 |
+
noisy_files += sorted(glob(join(args.noisy_dir, '*.wav')))
|
27 |
+
noisy_files += sorted(glob(join(args.noisy_dir, '**', '*.wav')))
|
28 |
+
for noisy_file in tqdm(noisy_files):
|
29 |
+
filename = noisy_file.replace(args.noisy_dir, "")[1:]
|
30 |
+
if 'dB' in filename:
|
31 |
+
clean_filename = filename.split("_")[0] + ".wav"
|
32 |
+
else:
|
33 |
+
clean_filename = filename
|
34 |
+
x, sr_x = read(join(args.clean_dir, clean_filename))
|
35 |
+
y, sr_y = read(join(args.noisy_dir, filename))
|
36 |
+
x_hat, sr_x_hat = read(join(args.enhanced_dir, filename))
|
37 |
+
assert sr_x == sr_y == sr_x_hat
|
38 |
+
n = y - x
|
39 |
+
x_hat_16k = librosa.resample(x_hat, orig_sr=sr_x_hat, target_sr=16000) if sr_x_hat != 16000 else x_hat
|
40 |
+
x_16k = librosa.resample(x, orig_sr=sr_x, target_sr=16000) if sr_x != 16000 else x
|
41 |
+
data["filename"].append(filename)
|
42 |
+
data["pesq"].append(pesq(16000, x_16k, x_hat_16k, 'wb'))
|
43 |
+
data["estoi"].append(stoi(x, x_hat, sr_x, extended=True))
|
44 |
+
data["si_sdr"].append(energy_ratios(x_hat, x, n)[0])
|
45 |
+
data["si_sir"].append(energy_ratios(x_hat, x, n)[1])
|
46 |
+
data["si_sar"].append(energy_ratios(x_hat, x, n)[2])
|
47 |
+
|
48 |
+
# Save results as DataFrame
|
49 |
+
df = pd.DataFrame(data)
|
50 |
+
|
51 |
+
# Print results
|
52 |
+
print("PESQ: {:.2f} ± {:.2f}".format(*mean_std(df["pesq"].to_numpy())))
|
53 |
+
print("ESTOI: {:.2f} ± {:.2f}".format(*mean_std(df["estoi"].to_numpy())))
|
54 |
+
print("SI-SDR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sdr"].to_numpy())))
|
55 |
+
print("SI-SIR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sir"].to_numpy())))
|
56 |
+
print("SI-SAR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sar"].to_numpy())))
|
57 |
+
|
58 |
+
# Save average results to file
|
59 |
+
log = open(join(args.enhanced_dir, "_avg_results.txt"), "w")
|
60 |
+
log.write("PESQ: {:.2f} ± {:.2f}".format(*mean_std(df["pesq"].to_numpy())) + "\n")
|
61 |
+
log.write("ESTOI: {:.2f} ± {:.2f}".format(*mean_std(df["estoi"].to_numpy())) + "\n")
|
62 |
+
log.write("SI-SDR: {:.1f} ± {:.2f}".format(*mean_std(df["si_sdr"].to_numpy())) + "\n")
|
63 |
+
log.write("SI-SIR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sir"].to_numpy())) + "\n")
|
64 |
+
log.write("SI-SAR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sar"].to_numpy())) + "\n")
|
65 |
+
|
66 |
+
# Save DataFrame as csv file
|
67 |
+
df.to_csv(join(args.enhanced_dir, "_results.csv"), index=False)
|
diffusion_process.png
ADDED
enhancement.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from os import makedirs
|
5 |
+
from soundfile import write
|
6 |
+
from torchaudio import load
|
7 |
+
from os.path import join, dirname
|
8 |
+
from argparse import ArgumentParser
|
9 |
+
from librosa import resample
|
10 |
+
|
11 |
+
# Set CUDA architecture list
|
12 |
+
from sgmse.util.other import set_torch_cuda_arch_list
|
13 |
+
set_torch_cuda_arch_list()
|
14 |
+
|
15 |
+
from sgmse.model import ScoreModel
|
16 |
+
from sgmse.util.other import pad_spec
|
17 |
+
|
18 |
+
|
19 |
+
if __name__ == '__main__':
|
20 |
+
parser = ArgumentParser()
|
21 |
+
parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the test data')
|
22 |
+
parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
|
23 |
+
parser.add_argument("--ckpt", type=str, help='Path to model checkpoint')
|
24 |
+
parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald", help="Corrector class for the PC sampler.")
|
25 |
+
parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
|
26 |
+
parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics")
|
27 |
+
parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
|
28 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference")
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
# Load score model
|
32 |
+
model = ScoreModel.load_from_checkpoint(args.ckpt, map_location=args.device)
|
33 |
+
model.eval()
|
34 |
+
|
35 |
+
# Get list of noisy files
|
36 |
+
noisy_files = []
|
37 |
+
noisy_files += sorted(glob.glob(join(args.test_dir, '*.wav')))
|
38 |
+
noisy_files += sorted(glob.glob(join(args.test_dir, '**', '*.wav')))
|
39 |
+
|
40 |
+
# Check if the model is trained on 48 kHz data
|
41 |
+
if model.backbone == 'ncsnpp_48k':
|
42 |
+
target_sr = 48000
|
43 |
+
pad_mode = "reflection"
|
44 |
+
else:
|
45 |
+
target_sr = 16000
|
46 |
+
pad_mode = "zero_pad"
|
47 |
+
|
48 |
+
# Enhance files
|
49 |
+
for noisy_file in tqdm(noisy_files):
|
50 |
+
filename = noisy_file.replace(args.test_dir, "")
|
51 |
+
filename = filename[1:] if filename.startswith("/") else filename
|
52 |
+
|
53 |
+
# Load wav
|
54 |
+
y, sr = load(noisy_file)
|
55 |
+
|
56 |
+
# Resample if necessary
|
57 |
+
if sr != target_sr:
|
58 |
+
y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
|
59 |
+
|
60 |
+
T_orig = y.size(1)
|
61 |
+
|
62 |
+
# Normalize
|
63 |
+
norm_factor = y.abs().max()
|
64 |
+
y = y / norm_factor
|
65 |
+
|
66 |
+
# Prepare DNN input
|
67 |
+
Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
|
68 |
+
Y = pad_spec(Y, mode=pad_mode)
|
69 |
+
|
70 |
+
# Reverse sampling
|
71 |
+
sampler = model.get_pc_sampler(
|
72 |
+
'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N,
|
73 |
+
corrector_steps=args.corrector_steps, snr=args.snr)
|
74 |
+
sample, _ = sampler()
|
75 |
+
|
76 |
+
# Backward transform in time domain
|
77 |
+
x_hat = model.to_audio(sample.squeeze(), T_orig)
|
78 |
+
|
79 |
+
# Renormalize
|
80 |
+
x_hat = x_hat * norm_factor
|
81 |
+
|
82 |
+
# Write enhanced wav file
|
83 |
+
makedirs(dirname(join(args.enhanced_dir, filename)), exist_ok=True)
|
84 |
+
write(join(args.enhanced_dir, filename), x_hat.cpu().numpy(), target_sr)
|
logs/.keep
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
preprocessing/create_wsj0_chime3.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
from librosa import load
|
4 |
+
from librosa.core import resample
|
5 |
+
import argparse
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from pathlib import Path
|
8 |
+
import numpy as np
|
9 |
+
from soundfile import write
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
|
13 |
+
# Python script for generating noisy mixtures for training
|
14 |
+
#
|
15 |
+
# Mix WSJ0 with CHiME3 noise with SNR sampled uniformly in [min_snr, max_snr]
|
16 |
+
|
17 |
+
|
18 |
+
min_snr = 0
|
19 |
+
max_snr = 20
|
20 |
+
sr = 16000
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == '__main__':
|
24 |
+
parser = ArgumentParser()
|
25 |
+
parser.add_argument("wsj0", type=str, help='path to WSJ0 directory')
|
26 |
+
parser.add_argument("chime3", type=str, help='path to CHiME3 directory')
|
27 |
+
parser.add_argument("target", type=str, help='target path for training files')
|
28 |
+
args = parser.parse_args()
|
29 |
+
|
30 |
+
# Clean speech for training
|
31 |
+
train_speech_files = sorted(glob(args.wsj0 + '**/si_tr_s/**/*.wav', recursive=True))
|
32 |
+
valid_speech_files = sorted(glob(args.wsj0 + '**/si_dt_05/**/*.wav', recursive=True))
|
33 |
+
test_speech_files = sorted(glob(args.wsj0 + '**/si_et_05/**/*.wav', recursive=True))
|
34 |
+
|
35 |
+
noise_files = glob(args.chime3 + '**/backgrounds/*.wav', recursive=True)
|
36 |
+
noise_files = [file for file in noise_files if (file[-7:-4] == "CH1")]
|
37 |
+
|
38 |
+
# Load CHiME3 noise files
|
39 |
+
noises = []
|
40 |
+
print('Loading CHiME3 noise files')
|
41 |
+
for file in noise_files:
|
42 |
+
noise = load(file, sr=None)[0]
|
43 |
+
noises.append(noise)
|
44 |
+
|
45 |
+
# Create target dir
|
46 |
+
train_clean_path = Path(os.path.join(args.target, 'train/clean'))
|
47 |
+
train_noisy_path = Path(os.path.join(args.target, 'train/noisy'))
|
48 |
+
valid_clean_path = Path(os.path.join(args.target, 'valid/clean'))
|
49 |
+
valid_noisy_path = Path(os.path.join(args.target, 'valid/noisy'))
|
50 |
+
test_clean_path = Path(os.path.join(args.target, 'test/clean'))
|
51 |
+
test_noisy_path = Path(os.path.join(args.target, 'test/noisy'))
|
52 |
+
|
53 |
+
train_clean_path.mkdir(parents=True, exist_ok=True)
|
54 |
+
train_noisy_path.mkdir(parents=True, exist_ok=True)
|
55 |
+
valid_clean_path.mkdir(parents=True, exist_ok=True)
|
56 |
+
valid_noisy_path.mkdir(parents=True, exist_ok=True)
|
57 |
+
test_clean_path.mkdir(parents=True, exist_ok=True)
|
58 |
+
test_noisy_path.mkdir(parents=True, exist_ok=True)
|
59 |
+
|
60 |
+
# Initialize seed for reproducability
|
61 |
+
np.random.seed(0)
|
62 |
+
|
63 |
+
# Create files for training
|
64 |
+
print('Create training files')
|
65 |
+
for i, speech_file in enumerate(tqdm(train_speech_files)):
|
66 |
+
s, _ = load(speech_file, sr=sr)
|
67 |
+
|
68 |
+
snr_dB = np.random.uniform(min_snr, max_snr)
|
69 |
+
noise_ind = np.random.randint(len(noises))
|
70 |
+
speech_power = 1/len(s)*np.sum(s**2)
|
71 |
+
|
72 |
+
n = noises[noise_ind]
|
73 |
+
start = np.random.randint(len(n)-len(s))
|
74 |
+
n = n[start:start+len(s)]
|
75 |
+
|
76 |
+
noise_power = 1/len(n)*np.sum(n**2)
|
77 |
+
noise_power_target = speech_power*np.power(10,-snr_dB/10)
|
78 |
+
k = noise_power_target / noise_power
|
79 |
+
n = n * np.sqrt(k)
|
80 |
+
x = s + n
|
81 |
+
|
82 |
+
file_name = speech_file.split('/')[-1]
|
83 |
+
write(os.path.join(train_clean_path, file_name), s, sr)
|
84 |
+
write(os.path.join(train_noisy_path, file_name), x, sr)
|
85 |
+
|
86 |
+
# Create files for validation
|
87 |
+
print('Create validation files')
|
88 |
+
for i, speech_file in enumerate(tqdm(valid_speech_files)):
|
89 |
+
s, _ = load(speech_file, sr=sr)
|
90 |
+
|
91 |
+
snr_dB = np.random.uniform(min_snr, max_snr)
|
92 |
+
noise_ind = np.random.randint(len(noises))
|
93 |
+
speech_power = 1/len(s)*np.sum(s**2)
|
94 |
+
|
95 |
+
n = noises[noise_ind]
|
96 |
+
start = np.random.randint(len(n)-len(s))
|
97 |
+
n = n[start:start+len(s)]
|
98 |
+
|
99 |
+
noise_power = 1/len(n)*np.sum(n**2)
|
100 |
+
noise_power_target = speech_power*np.power(10,-snr_dB/10)
|
101 |
+
k = noise_power_target / noise_power
|
102 |
+
n = n * np.sqrt(k)
|
103 |
+
x = s + n
|
104 |
+
|
105 |
+
file_name = speech_file.split('/')[-1]
|
106 |
+
write(os.path.join(valid_clean_path, file_name), s, sr)
|
107 |
+
write(os.path.join(valid_noisy_path, file_name), x, sr)
|
108 |
+
|
109 |
+
# Create files for test
|
110 |
+
print('Create test files')
|
111 |
+
for i, speech_file in enumerate(tqdm(test_speech_files)):
|
112 |
+
s, _ = load(speech_file, sr=sr)
|
113 |
+
|
114 |
+
snr_dB = np.random.uniform(min_snr, max_snr)
|
115 |
+
noise_ind = np.random.randint(len(noises))
|
116 |
+
speech_power = 1/len(s)*np.sum(s**2)
|
117 |
+
|
118 |
+
n = noises[noise_ind]
|
119 |
+
start = np.random.randint(len(n)-len(s))
|
120 |
+
n = n[start:start+len(s)]
|
121 |
+
|
122 |
+
noise_power = 1/len(n)*np.sum(n**2)
|
123 |
+
noise_power_target = speech_power*np.power(10,-snr_dB/10)
|
124 |
+
k = noise_power_target / noise_power
|
125 |
+
n = n * np.sqrt(k)
|
126 |
+
x = s + n
|
127 |
+
|
128 |
+
file_name = speech_file.split('/')[-1]
|
129 |
+
write(os.path.join(test_clean_path, file_name), s, sr)
|
130 |
+
write(os.path.join(test_noisy_path, file_name), x, sr)
|
preprocessing/create_wsj0_qut.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
from librosa import load
|
4 |
+
from librosa.core import resample
|
5 |
+
import argparse
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from pathlib import Path
|
8 |
+
import numpy as np
|
9 |
+
from soundfile import write
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
|
13 |
+
# Python script for generating noisy mixtures for training
|
14 |
+
#
|
15 |
+
# Mix WSJ0 with QUT noise with SNR sampled uniformly in [min_snr, max_snr]
|
16 |
+
|
17 |
+
|
18 |
+
min_snr = 0
|
19 |
+
max_snr = 15
|
20 |
+
sr = 16000
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == '__main__':
|
24 |
+
parser = ArgumentParser()
|
25 |
+
parser.add_argument("wsj0", type=str, help='path to WSJ0 directory')
|
26 |
+
parser.add_argument("qut", type=str, help='path to QUT directory')
|
27 |
+
parser.add_argument("target", type=str, help='target path for training files')
|
28 |
+
args = parser.parse_args()
|
29 |
+
|
30 |
+
# Clean speech for training
|
31 |
+
train_speech_files = sorted(glob(args.wsj0 + '**/si_tr_s/**/*.wav', recursive=True))
|
32 |
+
valid_speech_files = sorted(glob(args.wsj0 + '**/si_dt_05/**/*.wav', recursive=True))
|
33 |
+
test_speech_files = sorted(glob(args.wsj0 + '**/si_et_05/**/*.wav', recursive=True))
|
34 |
+
|
35 |
+
# Load QUT noise files
|
36 |
+
print('Loading QUT noise files')
|
37 |
+
cafe, sr_QUT = load(glob(args.qut + '**/CAFE-CAFE-1.wav', recursive=True)[0], sr=None)
|
38 |
+
car, sr_QUT = load(glob(args.qut + '**/CAR-WINDOWNB-1.wav', recursive=True)[0], sr=None)
|
39 |
+
home, sr_QUT = load(glob(args.qut + '**/HOME-KITCHEN-1.wav', recursive=True)[0], sr=None)
|
40 |
+
street, sr_QUT = load(glob(args.qut + '**/STREET-CITY-1.wav', recursive=True)[0], sr=None)
|
41 |
+
|
42 |
+
print('Resampling QUT noise files to 16kHz')
|
43 |
+
cafe = resample(cafe, sr_QUT, sr)
|
44 |
+
car = resample(car, sr_QUT, sr)
|
45 |
+
home = resample(home, sr_QUT, sr)
|
46 |
+
street = resample(street, sr_QUT, sr)
|
47 |
+
|
48 |
+
# ToDo: resampling with ffmpeg bacause librosa is soooo slow
|
49 |
+
# cafe, fs_QUT = load(os.path.join(args.qut, 'CAFE-CAFE-1_16k.wav'), sr=None)
|
50 |
+
# car, fs_QUT = load(os.path.join(args.qut, 'CAR-WINDOWNB-1_16k.wav'), sr=None)
|
51 |
+
# home, fs_QUT = load(os.path.join(args.qut, 'HOME-KITCHEN-1_16k.wav'), sr=None)
|
52 |
+
# street, fs_QUT = load(os.path.join(args.qut, 'STREET-CITY-1_16k.wav'), sr=None)
|
53 |
+
|
54 |
+
# Remove sweeps in the first and last 2 min in car noise file
|
55 |
+
car = car[120*sr:-120*sr]
|
56 |
+
|
57 |
+
# Create target dir
|
58 |
+
train_clean_path = Path(os.path.join(args.target, 'train/clean'))
|
59 |
+
train_noisy_path = Path(os.path.join(args.target, 'train/noisy'))
|
60 |
+
valid_clean_path = Path(os.path.join(args.target, 'valid/clean'))
|
61 |
+
valid_noisy_path = Path(os.path.join(args.target, 'valid/noisy'))
|
62 |
+
test_clean_path = Path(os.path.join(args.target, 'test/clean'))
|
63 |
+
test_noisy_path = Path(os.path.join(args.target, 'test/noisy'))
|
64 |
+
|
65 |
+
train_clean_path.mkdir(parents=True, exist_ok=True)
|
66 |
+
train_noisy_path.mkdir(parents=True, exist_ok=True)
|
67 |
+
valid_clean_path.mkdir(parents=True, exist_ok=True)
|
68 |
+
valid_noisy_path.mkdir(parents=True, exist_ok=True)
|
69 |
+
test_clean_path.mkdir(parents=True, exist_ok=True)
|
70 |
+
test_noisy_path.mkdir(parents=True, exist_ok=True)
|
71 |
+
|
72 |
+
# Initialize seed for reproducability
|
73 |
+
np.random.seed(0)
|
74 |
+
|
75 |
+
# Create files for training
|
76 |
+
print('Create training files')
|
77 |
+
for i, speech_file in enumerate(tqdm(train_speech_files)):
|
78 |
+
s, _ = load(speech_file, sr=sr)
|
79 |
+
|
80 |
+
snr_dB = np.random.uniform(min_snr, max_snr)
|
81 |
+
noise_type = np.random.randint(4)
|
82 |
+
speech_power = 1/len(s)*np.sum(s**2)
|
83 |
+
|
84 |
+
if noise_type == 0:
|
85 |
+
start = np.random.randint(len(cafe)-len(s))
|
86 |
+
n = cafe[start:start+len(s)]
|
87 |
+
elif noise_type == 1:
|
88 |
+
start = np.random.randint(len(home)-len(s))
|
89 |
+
n = home[start:start+len(s)]
|
90 |
+
elif noise_type == 2:
|
91 |
+
start = np.random.randint(len(street)-len(s))
|
92 |
+
n = street[start:start+len(s)]
|
93 |
+
elif noise_type == 3:
|
94 |
+
start = np.random.randint(len(car)-len(s))
|
95 |
+
n = car[start:start+len(s)]
|
96 |
+
else:
|
97 |
+
raise ValueError('Unexpected noise type index')
|
98 |
+
noise_power = 1/len(n)*np.sum(n**2)
|
99 |
+
noise_power_target = speech_power*np.power(10,-snr_dB/10)
|
100 |
+
k = noise_power_target / noise_power
|
101 |
+
n = n * np.sqrt(k)
|
102 |
+
x = s + n
|
103 |
+
|
104 |
+
file_name = speech_file.split('/')[-1]
|
105 |
+
write(os.path.join(train_clean_path, file_name), s, sr)
|
106 |
+
write(os.path.join(train_noisy_path, file_name), x, sr)
|
107 |
+
|
108 |
+
# Create files for validation
|
109 |
+
print('Create validation files')
|
110 |
+
for i, speech_file in enumerate(tqdm(valid_speech_files)):
|
111 |
+
s, _ = load(speech_file, sr=sr)
|
112 |
+
|
113 |
+
snr_dB = np.random.uniform(min_snr, max_snr)
|
114 |
+
noise_type = np.random.randint(4)
|
115 |
+
speech_power = 1/len(s)*np.sum(s**2)
|
116 |
+
|
117 |
+
if noise_type == 0:
|
118 |
+
start = np.random.randint(len(cafe)-len(s))
|
119 |
+
n = cafe[start:start+len(s)]
|
120 |
+
elif noise_type == 1:
|
121 |
+
start = np.random.randint(len(home)-len(s))
|
122 |
+
n = home[start:start+len(s)]
|
123 |
+
elif noise_type == 2:
|
124 |
+
start = np.random.randint(len(street)-len(s))
|
125 |
+
n = street[start:start+len(s)]
|
126 |
+
elif noise_type == 3:
|
127 |
+
start = np.random.randint(len(car)-len(s))
|
128 |
+
n = car[start:start+len(s)]
|
129 |
+
else:
|
130 |
+
raise ValueError('Unexpected noise type index')
|
131 |
+
noise_power = 1/len(n)*np.sum(n**2)
|
132 |
+
noise_power_target = speech_power*np.power(10,-snr_dB/10)
|
133 |
+
k = noise_power_target / noise_power
|
134 |
+
n = n * np.sqrt(k)
|
135 |
+
x = s + n
|
136 |
+
|
137 |
+
file_name = speech_file.split('/')[-1]
|
138 |
+
write(os.path.join(valid_clean_path, file_name), s, sr)
|
139 |
+
write(os.path.join(valid_noisy_path, file_name), x, sr)
|
140 |
+
|
141 |
+
# Create files for test
|
142 |
+
print('Create test files')
|
143 |
+
for i, speech_file in enumerate(tqdm(test_speech_files)):
|
144 |
+
s, _ = load(speech_file, sr=sr)
|
145 |
+
|
146 |
+
snr_dB = np.random.uniform(min_snr, max_snr)
|
147 |
+
noise_type = np.random.randint(4)
|
148 |
+
speech_power = 1/len(s)*np.sum(s**2)
|
149 |
+
|
150 |
+
if noise_type == 0:
|
151 |
+
start = np.random.randint(len(cafe)-len(s))
|
152 |
+
n = cafe[start:start+len(s)]
|
153 |
+
elif noise_type == 1:
|
154 |
+
start = np.random.randint(len(home)-len(s))
|
155 |
+
n = home[start:start+len(s)]
|
156 |
+
elif noise_type == 2:
|
157 |
+
start = np.random.randint(len(street)-len(s))
|
158 |
+
n = street[start:start+len(s)]
|
159 |
+
elif noise_type == 3:
|
160 |
+
start = np.random.randint(len(car)-len(s))
|
161 |
+
n = car[start:start+len(s)]
|
162 |
+
else:
|
163 |
+
raise ValueError('Unexpected noise type index')
|
164 |
+
noise_power = 1/len(n)*np.sum(n**2)
|
165 |
+
noise_power_target = speech_power*np.power(10,-snr_dB/10)
|
166 |
+
k = noise_power_target / noise_power
|
167 |
+
n = n * np.sqrt(k)
|
168 |
+
x = s + n
|
169 |
+
|
170 |
+
file_name = speech_file.split('/')[-1]
|
171 |
+
write(os.path.join(test_clean_path, file_name), s, sr)
|
172 |
+
write(os.path.join(test_noisy_path, file_name), x, sr)
|
preprocessing/create_wsj0_reverb.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import soundfile as sf
|
6 |
+
import pyroomacoustics as pra
|
7 |
+
from glob import glob
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
SEED = 100
|
12 |
+
np.random.seed(SEED)
|
13 |
+
|
14 |
+
T60_RANGE = [0.4, 1.0]
|
15 |
+
SNR_RANGE = [0, 20]
|
16 |
+
DIM_RANGE = [5, 15, 5, 15, 2, 6]
|
17 |
+
MIN_DISTANCE_TO_WALL = 1
|
18 |
+
MIC_ARRAY_RADIUS = 0.16
|
19 |
+
TARGET_T60_SHAPE = {"CI": 0.08, "HA": 0.2}
|
20 |
+
TARGET_T60_SHAPE = {"CI": 0.10, "HA": 0.2}
|
21 |
+
TARGETS_CROP = {"CI": 16e-3, "HA": 40e-3}
|
22 |
+
NB_SAMPLES_PER_ROOM = 1
|
23 |
+
CHANNELS = 1
|
24 |
+
|
25 |
+
if __name__ == '__main__':
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
parser.add_argument('--wsj0_dir', type=str, required=True, help='Path to the WSJ0 directory which should contain subdirectories "si_dt_05", "si_tr_s" and "si_et_05".')
|
28 |
+
parser.add_argument('--target_dir', type=str, required=True, help='Path to the target directory for saving WSJ0-REVERB.')
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
def obtain_clean_file(speech_list, i_sample, sample_rate=16000):
|
32 |
+
speech, speech_sr = sf.read(speech_list[i_sample])
|
33 |
+
speech_basename = os.path.basename(speech_list[i_sample])
|
34 |
+
assert speech_sr == sample_rate, f"wrong speech sampling rate here: expected {sample_rate} got {speech_sr}"
|
35 |
+
return speech.squeeze(), speech_sr, speech_basename[: -4]
|
36 |
+
|
37 |
+
splits = ['valid', 'train', 'test']
|
38 |
+
dic_split = {"valid": "si_dt_05", "train": "si_tr_s", "test": "si_et_05"}
|
39 |
+
speech_lists = {split:sorted(glob(f"{os.path.join(args.wsj0_dir, dic_split[split])}/**/*.wav")) for split in splits}
|
40 |
+
sample_rate = 16000
|
41 |
+
output_dir = args.target_dir
|
42 |
+
|
43 |
+
if os.path.exists(output_dir):
|
44 |
+
shutil.rmtree(output_dir)
|
45 |
+
|
46 |
+
for i_split, split in enumerate(splits):
|
47 |
+
print("Processing split n° {}: {}...".format(i_split+1, split))
|
48 |
+
|
49 |
+
reverberant_output_dir = os.path.join(output_dir, "audio", split, "reverb")
|
50 |
+
dry_output_dir = os.path.join(output_dir, "audio", split, "anechoic")
|
51 |
+
noisy_reverberant_output_dir = os.path.join(output_dir, "audio", split, "noisy_reverb")
|
52 |
+
if split == "test":
|
53 |
+
unauralized_output_dir = os.path.join(output_dir, "audio", split, "unauralized")
|
54 |
+
|
55 |
+
os.makedirs(reverberant_output_dir, exist_ok=True)
|
56 |
+
os.makedirs(dry_output_dir, exist_ok=True)
|
57 |
+
if split == "test":
|
58 |
+
os.makedirs(unauralized_output_dir, exist_ok=True)
|
59 |
+
|
60 |
+
speech_list = speech_lists[split]
|
61 |
+
speech_dir = None
|
62 |
+
real_nb_samples = len(speech_list)
|
63 |
+
|
64 |
+
for i_sample in tqdm(range(real_nb_samples)):
|
65 |
+
if not i_sample % NB_SAMPLES_PER_ROOM: #Generate new room
|
66 |
+
t60 = np.random.uniform(T60_RANGE[0], T60_RANGE[1]) #Draw T60
|
67 |
+
room_dim = np.array([ np.random.uniform(DIM_RANGE[2*n], DIM_RANGE[2*n+1]) for n in range(3) ]) #Draw Dimensions
|
68 |
+
center_mic_position = np.array([ np.random.uniform(MIN_DISTANCE_TO_WALL, room_dim[n] - MIN_DISTANCE_TO_WALL) for n in range(3) ]) #draw source position
|
69 |
+
source_position = np.array([ np.random.uniform(MIN_DISTANCE_TO_WALL, room_dim[n] - MIN_DISTANCE_TO_WALL) for n in range(3) ]) #draw source position
|
70 |
+
mic_array_2d = pra.beamforming.circular_2D_array(center_mic_position[: -1], CHANNELS, phi0=0, radius=MIC_ARRAY_RADIUS) # Compute microphone array
|
71 |
+
mic_array = np.pad(mic_array_2d, ((0, 1), (0, 0)), mode="constant", constant_values=center_mic_position[-1])
|
72 |
+
|
73 |
+
### Reverberant Room
|
74 |
+
e_absorption, max_order = pra.inverse_sabine(t60, room_dim) #Compute absorption coeff
|
75 |
+
reverberant_room = pra.ShoeBox(
|
76 |
+
room_dim, fs=16000, materials=pra.Material(e_absorption), max_order=min(3, max_order)
|
77 |
+
) # Create room
|
78 |
+
reverberant_room.set_ray_tracing()
|
79 |
+
reverberant_room.add_microphone_array(mic_array) # Add microphone array
|
80 |
+
|
81 |
+
# Pick unauralized files
|
82 |
+
speech, speech_sr, speech_basename = obtain_clean_file(speech_list, i_sample, sample_rate=sample_rate)
|
83 |
+
|
84 |
+
# Generate reverberant room
|
85 |
+
reverberant_room.add_source(source_position, signal=speech)
|
86 |
+
reverberant_room.compute_rir()
|
87 |
+
reverberant_room.simulate()
|
88 |
+
t60_real = np.mean(reverberant_room.measure_rt60()).squeeze()
|
89 |
+
reverberant = np.stack(reverberant_room.mic_array.signals).swapaxes(0, 1)
|
90 |
+
|
91 |
+
e_absorption_dry = 0.99 #For Neural Networks OK but clearly not for WPE
|
92 |
+
dry_room = pra.ShoeBox(
|
93 |
+
room_dim, fs=16000, materials=pra.Material(e_absorption_dry), max_order=0
|
94 |
+
) # Create room
|
95 |
+
dry_room.add_microphone_array(mic_array) # Add microphone array
|
96 |
+
|
97 |
+
# Generate dry room
|
98 |
+
dry_room.add_source(source_position, signal=speech)
|
99 |
+
dry_room.compute_rir()
|
100 |
+
dry_room.simulate()
|
101 |
+
t60_real_dry = np.mean(dry_room.measure_rt60()).squeeze()
|
102 |
+
rir_dry = dry_room.rir
|
103 |
+
dry = np.stack(dry_room.mic_array.signals).swapaxes(0, 1)
|
104 |
+
dry = np.pad(dry, ((0, int(.5*sample_rate)), (0, 0)), mode="constant", constant_values=0) #Add 1 second of silence after dry (because very dry) so that the reverb is not cut, and all samples have same length
|
105 |
+
|
106 |
+
min_len_sample = min(reverberant.shape[0], dry.shape[0])
|
107 |
+
dry = dry[: min_len_sample]
|
108 |
+
reverberant = reverberant[: min_len_sample]
|
109 |
+
output_scaling = np.max(reverberant) / .9
|
110 |
+
|
111 |
+
drr = 10*np.log10( np.mean(dry**2) / (np.mean(reverberant**2) + 1e-8) + 1e-8 )
|
112 |
+
output_filename = f"{speech_basename}_{i_sample//NB_SAMPLES_PER_ROOM}_{t60_real:.2f}_{drr:.1f}.wav"
|
113 |
+
|
114 |
+
sf.write(os.path.join(dry_output_dir, output_filename), 1/output_scaling*dry, samplerate=sample_rate)
|
115 |
+
sf.write(os.path.join(reverberant_output_dir, output_filename), 1/output_scaling*reverberant, samplerate=sample_rate)
|
116 |
+
|
117 |
+
if split == "test":
|
118 |
+
sf.write(os.path.join(unauralized_output_dir, output_filename), speech, samplerate=sample_rate)
|
requirements.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gdown
|
2 |
+
h5py
|
3 |
+
ipympl
|
4 |
+
librosa
|
5 |
+
ninja
|
6 |
+
numpy<2.0
|
7 |
+
pandas
|
8 |
+
pesq
|
9 |
+
pillow
|
10 |
+
protobuf
|
11 |
+
pyarrow
|
12 |
+
pyroomacoustics
|
13 |
+
pystoi
|
14 |
+
pytorch-lightning
|
15 |
+
scipy
|
16 |
+
sdeint
|
17 |
+
setuptools
|
18 |
+
seaborn
|
19 |
+
torch
|
20 |
+
torch-ema
|
21 |
+
torchaudio
|
22 |
+
torchvision
|
23 |
+
torchinfo
|
24 |
+
torchsde
|
25 |
+
tqdm
|
26 |
+
wandb
|
requirements_version.txt
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
h5py==3.10.0
|
2 |
+
ipympl==0.9.3
|
3 |
+
librosa==0.10.1
|
4 |
+
ninja==1.11.1.1
|
5 |
+
numpy==1.24.4
|
6 |
+
pandas==2.0.3
|
7 |
+
pesq==0.0.4
|
8 |
+
pillow==10.2.0
|
9 |
+
protobuf==4.25.2
|
10 |
+
pyarrow==15.0.0
|
11 |
+
pyroomacoustics==0.7.3
|
12 |
+
pystoi==0.4.1
|
13 |
+
pytorch-lightning==2.1.4
|
14 |
+
scipy==1.10.1
|
15 |
+
sdeint==0.3.0
|
16 |
+
setuptools==44.0.0
|
17 |
+
seaborn==0.13.2
|
18 |
+
torch==2.2.0
|
19 |
+
torch-ema==0.3
|
20 |
+
torchaudio==2.2.0
|
21 |
+
torchvision==0.17.0
|
22 |
+
torchinfo==1.8.0
|
23 |
+
torchsde==0.2.6
|
24 |
+
tqdm==4.66.1
|
25 |
+
wandb==0.16.2
|
sgmse/backbones/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .shared import BackboneRegistry
|
2 |
+
from .ncsnpp import NCSNpp
|
3 |
+
from .ncsnpp_48k import NCSNpp_48k
|
4 |
+
from .dcunet import DCUNet
|
5 |
+
|
6 |
+
__all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_48k', 'DCUNet']
|
sgmse/backbones/dcunet.py
ADDED
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn, Tensor
|
6 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
7 |
+
|
8 |
+
from .shared import BackboneRegistry, ComplexConv2d, ComplexConvTranspose2d, ComplexLinear, \
|
9 |
+
DiffusionStepEmbedding, GaussianFourierProjection, FeatureMapDense, torch_complex_from_reim
|
10 |
+
|
11 |
+
|
12 |
+
def get_activation(name):
|
13 |
+
if name == "silu":
|
14 |
+
return nn.SiLU
|
15 |
+
elif name == "relu":
|
16 |
+
return nn.ReLU
|
17 |
+
elif name == "leaky_relu":
|
18 |
+
return nn.LeakyReLU
|
19 |
+
else:
|
20 |
+
raise NotImplementedError(f"Unknown activation: {name}")
|
21 |
+
|
22 |
+
|
23 |
+
class BatchNorm(_BatchNorm):
|
24 |
+
def _check_input_dim(self, input):
|
25 |
+
if input.dim() < 2 or input.dim() > 4:
|
26 |
+
raise ValueError("expected 4D or 3D input (got {}D input)".format(input.dim()))
|
27 |
+
|
28 |
+
|
29 |
+
class OnReIm(nn.Module):
|
30 |
+
def __init__(self, module_cls, *args, **kwargs):
|
31 |
+
super().__init__()
|
32 |
+
self.re_module = module_cls(*args, **kwargs)
|
33 |
+
self.im_module = module_cls(*args, **kwargs)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
|
37 |
+
|
38 |
+
|
39 |
+
# Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers!
|
40 |
+
|
41 |
+
def unet_decoder_args(encoders, *, skip_connections):
|
42 |
+
"""Get list of decoder arguments for upsampling (right) side of a symmetric u-net,
|
43 |
+
given the arguments used to construct the encoder.
|
44 |
+
Args:
|
45 |
+
encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding)):
|
46 |
+
List of arguments used to construct the encoders
|
47 |
+
skip_connections (bool): Whether to include skip connections in the
|
48 |
+
calculation of decoder input channels.
|
49 |
+
Return:
|
50 |
+
tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding):
|
51 |
+
Arguments to be used to construct decoders
|
52 |
+
"""
|
53 |
+
decoder_args = []
|
54 |
+
for enc_in_chan, enc_out_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation in reversed(encoders):
|
55 |
+
if skip_connections and decoder_args:
|
56 |
+
skip_in_chan = enc_out_chan
|
57 |
+
else:
|
58 |
+
skip_in_chan = 0
|
59 |
+
decoder_args.append(
|
60 |
+
(enc_out_chan + skip_in_chan, enc_in_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation)
|
61 |
+
)
|
62 |
+
return tuple(decoder_args)
|
63 |
+
|
64 |
+
|
65 |
+
def make_unet_encoder_decoder_args(encoder_args, decoder_args):
|
66 |
+
encoder_args = tuple(
|
67 |
+
(
|
68 |
+
in_chan,
|
69 |
+
out_chan,
|
70 |
+
tuple(kernel_size),
|
71 |
+
tuple(stride),
|
72 |
+
tuple([n // 2 for n in kernel_size]) if padding == "auto" else tuple(padding),
|
73 |
+
tuple(dilation)
|
74 |
+
)
|
75 |
+
for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args
|
76 |
+
)
|
77 |
+
|
78 |
+
if decoder_args == "auto":
|
79 |
+
decoder_args = unet_decoder_args(
|
80 |
+
encoder_args,
|
81 |
+
skip_connections=True,
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
decoder_args = tuple(
|
85 |
+
(
|
86 |
+
in_chan,
|
87 |
+
out_chan,
|
88 |
+
tuple(kernel_size),
|
89 |
+
tuple(stride),
|
90 |
+
tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding,
|
91 |
+
tuple(dilation),
|
92 |
+
output_padding,
|
93 |
+
)
|
94 |
+
for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args
|
95 |
+
)
|
96 |
+
|
97 |
+
return encoder_args, decoder_args
|
98 |
+
|
99 |
+
|
100 |
+
DCUNET_ARCHITECTURES = {
|
101 |
+
"DCUNet-10": make_unet_encoder_decoder_args(
|
102 |
+
# Encoders:
|
103 |
+
# (in_chan, out_chan, kernel_size, stride, padding, dilation)
|
104 |
+
(
|
105 |
+
(1, 32, (7, 5), (2, 2), "auto", (1,1)),
|
106 |
+
(32, 64, (7, 5), (2, 2), "auto", (1,1)),
|
107 |
+
(64, 64, (5, 3), (2, 2), "auto", (1,1)),
|
108 |
+
(64, 64, (5, 3), (2, 2), "auto", (1,1)),
|
109 |
+
(64, 64, (5, 3), (2, 1), "auto", (1,1)),
|
110 |
+
),
|
111 |
+
# Decoders: automatic inverse
|
112 |
+
"auto",
|
113 |
+
),
|
114 |
+
"DCUNet-16": make_unet_encoder_decoder_args(
|
115 |
+
# Encoders:
|
116 |
+
# (in_chan, out_chan, kernel_size, stride, padding, dilation)
|
117 |
+
(
|
118 |
+
(1, 32, (7, 5), (2, 2), "auto", (1,1)),
|
119 |
+
(32, 32, (7, 5), (2, 1), "auto", (1,1)),
|
120 |
+
(32, 64, (7, 5), (2, 2), "auto", (1,1)),
|
121 |
+
(64, 64, (5, 3), (2, 1), "auto", (1,1)),
|
122 |
+
(64, 64, (5, 3), (2, 2), "auto", (1,1)),
|
123 |
+
(64, 64, (5, 3), (2, 1), "auto", (1,1)),
|
124 |
+
(64, 64, (5, 3), (2, 2), "auto", (1,1)),
|
125 |
+
(64, 64, (5, 3), (2, 1), "auto", (1,1)),
|
126 |
+
),
|
127 |
+
# Decoders: automatic inverse
|
128 |
+
"auto",
|
129 |
+
),
|
130 |
+
"DCUNet-20": make_unet_encoder_decoder_args(
|
131 |
+
# Encoders:
|
132 |
+
# (in_chan, out_chan, kernel_size, stride, padding, dilation)
|
133 |
+
(
|
134 |
+
(1, 32, (7, 1), (1, 1), "auto", (1,1)),
|
135 |
+
(32, 32, (1, 7), (1, 1), "auto", (1,1)),
|
136 |
+
(32, 64, (7, 5), (2, 2), "auto", (1,1)),
|
137 |
+
(64, 64, (7, 5), (2, 1), "auto", (1,1)),
|
138 |
+
(64, 64, (5, 3), (2, 2), "auto", (1,1)),
|
139 |
+
(64, 64, (5, 3), (2, 1), "auto", (1,1)),
|
140 |
+
(64, 64, (5, 3), (2, 2), "auto", (1,1)),
|
141 |
+
(64, 64, (5, 3), (2, 1), "auto", (1,1)),
|
142 |
+
(64, 64, (5, 3), (2, 2), "auto", (1,1)),
|
143 |
+
(64, 90, (5, 3), (2, 1), "auto", (1,1)),
|
144 |
+
),
|
145 |
+
# Decoders: automatic inverse
|
146 |
+
"auto",
|
147 |
+
),
|
148 |
+
"DilDCUNet-v2": make_unet_encoder_decoder_args( # architecture used in SGMSE / Interspeech paper
|
149 |
+
# Encoders:
|
150 |
+
# (in_chan, out_chan, kernel_size, stride, padding, dilation)
|
151 |
+
(
|
152 |
+
(1, 32, (4, 4), (1, 1), "auto", (1, 1)),
|
153 |
+
(32, 32, (4, 4), (1, 1), "auto", (1, 1)),
|
154 |
+
(32, 32, (4, 4), (1, 1), "auto", (1, 1)),
|
155 |
+
(32, 64, (4, 4), (2, 1), "auto", (2, 1)),
|
156 |
+
(64, 128, (4, 4), (2, 2), "auto", (4, 1)),
|
157 |
+
(128, 256, (4, 4), (2, 2), "auto", (8, 1)),
|
158 |
+
),
|
159 |
+
# Decoders: automatic inverse
|
160 |
+
"auto",
|
161 |
+
),
|
162 |
+
}
|
163 |
+
|
164 |
+
|
165 |
+
@BackboneRegistry.register("dcunet")
|
166 |
+
class DCUNet(nn.Module):
|
167 |
+
@staticmethod
|
168 |
+
def add_argparse_args(parser):
|
169 |
+
parser.add_argument("--dcunet-architecture", type=str, default="DilDCUNet-v2", choices=DCUNET_ARCHITECTURES.keys(), help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.")
|
170 |
+
parser.add_argument("--dcunet-time-embedding", type=str, choices=("gfp", "ds", "none"), default="gfp", help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.")
|
171 |
+
parser.add_argument("--dcunet-temb-layers-global", type=int, default=1, help="Number of global linear+activation layers for the time embedding. 1 by default.")
|
172 |
+
parser.add_argument("--dcunet-temb-layers-local", type=int, default=1, help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.")
|
173 |
+
parser.add_argument("--dcunet-temb-activation", type=str, default="silu", help="The (complex) activation to use between all (global&local) time embedding layers.")
|
174 |
+
parser.add_argument("--dcunet-time-embedding-complex", action="store_true", help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.")
|
175 |
+
parser.add_argument("--dcunet-fix-length", type=str, default="pad", choices=("pad", "trim", "none"), help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.")
|
176 |
+
parser.add_argument("--dcunet-mask-bound", type=str, choices=("tanh", "sigmoid", "none"), default="none", help="DCUNet output bounding strategy. 'none' by default.")
|
177 |
+
parser.add_argument("--dcunet-norm-type", type=str, choices=("bN", "CbN"), default="bN", help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.")
|
178 |
+
parser.add_argument("--dcunet-activation", type=str, choices=("leaky_relu", "relu", "silu"), default="leaky_relu", help="The activation to use within each encoder and decoder layer. 'leaky_relu' by default.")
|
179 |
+
return parser
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
dcunet_architecture: str = "DilDCUNet-v2",
|
184 |
+
dcunet_time_embedding: str = "gfp",
|
185 |
+
dcunet_temb_layers_global: int = 2,
|
186 |
+
dcunet_temb_layers_local: int = 1,
|
187 |
+
dcunet_temb_activation: str = "silu",
|
188 |
+
dcunet_time_embedding_complex: bool = False,
|
189 |
+
dcunet_fix_length: str = "pad",
|
190 |
+
dcunet_mask_bound: str = "none",
|
191 |
+
dcunet_norm_type: str = "bN",
|
192 |
+
dcunet_activation: str = "relu",
|
193 |
+
embed_dim: int = 128,
|
194 |
+
**kwargs
|
195 |
+
):
|
196 |
+
super().__init__()
|
197 |
+
|
198 |
+
self.architecture = dcunet_architecture
|
199 |
+
self.fix_length_mode = (dcunet_fix_length if dcunet_fix_length != "none" else None)
|
200 |
+
self.norm_type = dcunet_norm_type
|
201 |
+
self.activation = dcunet_activation
|
202 |
+
self.input_channels = 2 # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN
|
203 |
+
self.time_embedding = (dcunet_time_embedding if dcunet_time_embedding != "none" else None)
|
204 |
+
self.time_embedding_complex = dcunet_time_embedding_complex
|
205 |
+
self.temb_layers_global = dcunet_temb_layers_global
|
206 |
+
self.temb_layers_local = dcunet_temb_layers_local
|
207 |
+
self.temb_activation = dcunet_temb_activation
|
208 |
+
conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture]
|
209 |
+
|
210 |
+
# Replace `input_channels` in encoders config
|
211 |
+
_replaced_input_channels, *rest = conf_encoders[0]
|
212 |
+
encoders = ((self.input_channels, *rest), *conf_encoders[1:])
|
213 |
+
decoders = conf_decoders
|
214 |
+
self.encoders_stride_product = np.prod(
|
215 |
+
[enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0
|
216 |
+
)
|
217 |
+
|
218 |
+
# Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation)
|
219 |
+
encoder_decoder_kwargs = dict(
|
220 |
+
norm_type=self.norm_type, activation=self.activation,
|
221 |
+
temb_layers=self.temb_layers_local, temb_activation=self.temb_activation)
|
222 |
+
|
223 |
+
# Instantiate (global) time embedding layer
|
224 |
+
embed_ops = []
|
225 |
+
if self.time_embedding is not None:
|
226 |
+
complex_valued = self.time_embedding_complex
|
227 |
+
if self.time_embedding == "gfp":
|
228 |
+
embed_ops += [GaussianFourierProjection(embed_dim=embed_dim, complex_valued=complex_valued)]
|
229 |
+
encoder_decoder_kwargs["embed_dim"] = embed_dim
|
230 |
+
elif self.time_embedding == "ds":
|
231 |
+
embed_ops += [DiffusionStepEmbedding(embed_dim=embed_dim, complex_valued=complex_valued)]
|
232 |
+
encoder_decoder_kwargs["embed_dim"] = embed_dim
|
233 |
+
|
234 |
+
if self.time_embedding_complex:
|
235 |
+
assert self.time_embedding in ("gfp", "ds"), "Complex timestep embedding only available for gfp and ds"
|
236 |
+
encoder_decoder_kwargs["complex_time_embedding"] = True
|
237 |
+
for _ in range(self.temb_layers_global):
|
238 |
+
embed_ops += [
|
239 |
+
ComplexLinear(embed_dim, embed_dim, complex_valued=True),
|
240 |
+
OnReIm(get_activation(dcunet_temb_activation))
|
241 |
+
]
|
242 |
+
self.embed = nn.Sequential(*embed_ops)
|
243 |
+
|
244 |
+
### Instantiate DCUNet layers ###
|
245 |
+
output_layer = ComplexConvTranspose2d(*decoders[-1])
|
246 |
+
encoders = [DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs) for args in encoders]
|
247 |
+
decoders = [DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs) for args in decoders[:-1]]
|
248 |
+
|
249 |
+
self.mask_bound = (dcunet_mask_bound if dcunet_mask_bound != "none" else None)
|
250 |
+
if self.mask_bound is not None:
|
251 |
+
raise NotImplementedError("sorry, mask bounding not implemented at the moment")
|
252 |
+
# TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument
|
253 |
+
#operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound))
|
254 |
+
#output_layer = nn.Sequential(*[x for x in operations if x is not None])
|
255 |
+
|
256 |
+
assert len(encoders) == len(decoders) + 1
|
257 |
+
self.encoders = nn.ModuleList(encoders)
|
258 |
+
self.decoders = nn.ModuleList(decoders)
|
259 |
+
self.output_layer = output_layer or nn.Identity()
|
260 |
+
|
261 |
+
def forward(self, spec, t) -> Tensor:
|
262 |
+
"""
|
263 |
+
Input shape is expected to be $(batch, nfreqs, time)$, with $nfreqs - 1$ divisible
|
264 |
+
by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the encoders,
|
265 |
+
and $time - 1$ is divisible by $t_0 * t_1 * ... * t_N$ where $t_N$ are the time
|
266 |
+
strides of the encoders.
|
267 |
+
Args:
|
268 |
+
spec (Tensor): complex spectrogram tensor. 1D, 2D or 3D tensor, time last.
|
269 |
+
Returns:
|
270 |
+
Tensor, of shape (batch, time) or (time).
|
271 |
+
"""
|
272 |
+
# TF-rep shape: (batch, self.input_channels, n_fft, frames)
|
273 |
+
# Estimate mask from time-frequency representation.
|
274 |
+
x_in = self.fix_input_dims(spec)
|
275 |
+
x = x_in
|
276 |
+
t_embed = self.embed(t+0j) if self.time_embedding is not None else None
|
277 |
+
|
278 |
+
enc_outs = []
|
279 |
+
for idx, enc in enumerate(self.encoders):
|
280 |
+
x = enc(x, t_embed)
|
281 |
+
# UNet skip connection
|
282 |
+
enc_outs.append(x)
|
283 |
+
for (enc_out, dec) in zip(reversed(enc_outs[:-1]), self.decoders):
|
284 |
+
x = dec(x, t_embed, output_size=enc_out.shape)
|
285 |
+
x = torch.cat([x, enc_out], dim=1)
|
286 |
+
|
287 |
+
output = self.output_layer(x, output_size=x_in.shape)
|
288 |
+
# output shape: (batch, 1, n_fft, frames)
|
289 |
+
output = self.fix_output_dims(output, spec)
|
290 |
+
return output
|
291 |
+
|
292 |
+
def fix_input_dims(self, x):
|
293 |
+
return _fix_dcu_input_dims(
|
294 |
+
self.fix_length_mode, x, torch.from_numpy(self.encoders_stride_product)
|
295 |
+
)
|
296 |
+
|
297 |
+
def fix_output_dims(self, out, x):
|
298 |
+
return _fix_dcu_output_dims(self.fix_length_mode, out, x)
|
299 |
+
|
300 |
+
|
301 |
+
def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product):
|
302 |
+
"""Pad or trim `x` to a length compatible with DCUNet."""
|
303 |
+
freq_prod = int(encoders_stride_product[0])
|
304 |
+
time_prod = int(encoders_stride_product[1])
|
305 |
+
if (x.shape[2] - 1) % freq_prod:
|
306 |
+
raise TypeError(
|
307 |
+
f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by "
|
308 |
+
f"{freq_prod}, got {x.shape} instead"
|
309 |
+
)
|
310 |
+
time_remainder = (x.shape[3] - 1) % time_prod
|
311 |
+
if time_remainder:
|
312 |
+
if fix_length_mode is None:
|
313 |
+
raise TypeError(
|
314 |
+
f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by "
|
315 |
+
f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument "
|
316 |
+
f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically."
|
317 |
+
)
|
318 |
+
elif fix_length_mode == "pad":
|
319 |
+
pad_shape = [0, time_prod - time_remainder]
|
320 |
+
x = nn.functional.pad(x, pad_shape, mode="constant")
|
321 |
+
elif fix_length_mode == "trim":
|
322 |
+
pad_shape = [0, -time_remainder]
|
323 |
+
x = nn.functional.pad(x, pad_shape, mode="constant")
|
324 |
+
else:
|
325 |
+
raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'")
|
326 |
+
return x
|
327 |
+
|
328 |
+
|
329 |
+
def _fix_dcu_output_dims(fix_length_mode, out, x):
|
330 |
+
"""Fix shape of `out` to the original shape of `x` by padding/cropping."""
|
331 |
+
inp_len = x.shape[-1]
|
332 |
+
output_len = out.shape[-1]
|
333 |
+
return nn.functional.pad(out, [0, inp_len - output_len])
|
334 |
+
|
335 |
+
|
336 |
+
def _get_norm(norm_type):
|
337 |
+
if norm_type == "CbN":
|
338 |
+
return ComplexBatchNorm
|
339 |
+
elif norm_type == "bN":
|
340 |
+
return partial(OnReIm, BatchNorm)
|
341 |
+
else:
|
342 |
+
raise NotImplementedError(f"Unknown norm type: {norm_type}")
|
343 |
+
|
344 |
+
|
345 |
+
class DCUNetComplexEncoderBlock(nn.Module):
|
346 |
+
def __init__(
|
347 |
+
self,
|
348 |
+
in_chan,
|
349 |
+
out_chan,
|
350 |
+
kernel_size,
|
351 |
+
stride,
|
352 |
+
padding,
|
353 |
+
dilation,
|
354 |
+
norm_type="bN",
|
355 |
+
activation="leaky_relu",
|
356 |
+
embed_dim=None,
|
357 |
+
complex_time_embedding=False,
|
358 |
+
temb_layers=1,
|
359 |
+
temb_activation="silu"
|
360 |
+
):
|
361 |
+
super().__init__()
|
362 |
+
|
363 |
+
self.in_chan = in_chan
|
364 |
+
self.out_chan = out_chan
|
365 |
+
self.kernel_size = kernel_size
|
366 |
+
self.stride = stride
|
367 |
+
self.padding = padding
|
368 |
+
self.dilation = dilation
|
369 |
+
self.temb_layers = temb_layers
|
370 |
+
self.temb_activation = temb_activation
|
371 |
+
self.complex_time_embedding = complex_time_embedding
|
372 |
+
|
373 |
+
self.conv = ComplexConv2d(
|
374 |
+
in_chan, out_chan, kernel_size, stride, padding, bias=norm_type is None, dilation=dilation
|
375 |
+
)
|
376 |
+
self.norm = _get_norm(norm_type)(out_chan)
|
377 |
+
self.activation = OnReIm(get_activation(activation))
|
378 |
+
self.embed_dim = embed_dim
|
379 |
+
if self.embed_dim is not None:
|
380 |
+
ops = []
|
381 |
+
for _ in range(max(0, self.temb_layers - 1)):
|
382 |
+
ops += [
|
383 |
+
ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
|
384 |
+
OnReIm(get_activation(self.temb_activation))
|
385 |
+
]
|
386 |
+
ops += [
|
387 |
+
FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
|
388 |
+
OnReIm(get_activation(self.temb_activation))
|
389 |
+
]
|
390 |
+
self.embed_layer = nn.Sequential(*ops)
|
391 |
+
|
392 |
+
def forward(self, x, t_embed):
|
393 |
+
y = self.conv(x)
|
394 |
+
if self.embed_dim is not None:
|
395 |
+
y = y + self.embed_layer(t_embed)
|
396 |
+
return self.activation(self.norm(y))
|
397 |
+
|
398 |
+
|
399 |
+
class DCUNetComplexDecoderBlock(nn.Module):
|
400 |
+
def __init__(
|
401 |
+
self,
|
402 |
+
in_chan,
|
403 |
+
out_chan,
|
404 |
+
kernel_size,
|
405 |
+
stride,
|
406 |
+
padding,
|
407 |
+
dilation,
|
408 |
+
output_padding=(0, 0),
|
409 |
+
norm_type="bN",
|
410 |
+
activation="leaky_relu",
|
411 |
+
embed_dim=None,
|
412 |
+
temb_layers=1,
|
413 |
+
temb_activation='swish',
|
414 |
+
complex_time_embedding=False,
|
415 |
+
):
|
416 |
+
super().__init__()
|
417 |
+
|
418 |
+
self.in_chan = in_chan
|
419 |
+
self.out_chan = out_chan
|
420 |
+
self.kernel_size = kernel_size
|
421 |
+
self.stride = stride
|
422 |
+
self.padding = padding
|
423 |
+
self.dilation = dilation
|
424 |
+
self.output_padding = output_padding
|
425 |
+
self.complex_time_embedding = complex_time_embedding
|
426 |
+
self.temb_layers = temb_layers
|
427 |
+
self.temb_activation = temb_activation
|
428 |
+
|
429 |
+
self.deconv = ComplexConvTranspose2d(
|
430 |
+
in_chan, out_chan, kernel_size, stride, padding, output_padding, dilation=dilation, bias=norm_type is None
|
431 |
+
)
|
432 |
+
self.norm = _get_norm(norm_type)(out_chan)
|
433 |
+
self.activation = OnReIm(get_activation(activation))
|
434 |
+
self.embed_dim = embed_dim
|
435 |
+
if self.embed_dim is not None:
|
436 |
+
ops = []
|
437 |
+
for _ in range(max(0, self.temb_layers - 1)):
|
438 |
+
ops += [
|
439 |
+
ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
|
440 |
+
OnReIm(get_activation(self.temb_activation))
|
441 |
+
]
|
442 |
+
ops += [
|
443 |
+
FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
|
444 |
+
OnReIm(get_activation(self.temb_activation))
|
445 |
+
]
|
446 |
+
self.embed_layer = nn.Sequential(*ops)
|
447 |
+
|
448 |
+
def forward(self, x, t_embed, output_size=None):
|
449 |
+
y = self.deconv(x, output_size=output_size)
|
450 |
+
if self.embed_dim is not None:
|
451 |
+
y = y + self.embed_layer(t_embed)
|
452 |
+
return self.activation(self.norm(y))
|
453 |
+
|
454 |
+
|
455 |
+
# From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py
|
456 |
+
class ComplexBatchNorm(torch.nn.Module):
|
457 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=False):
|
458 |
+
super(ComplexBatchNorm, self).__init__()
|
459 |
+
self.num_features = num_features
|
460 |
+
self.eps = eps
|
461 |
+
self.momentum = momentum
|
462 |
+
self.affine = affine
|
463 |
+
self.track_running_stats = track_running_stats
|
464 |
+
if self.affine:
|
465 |
+
self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
|
466 |
+
self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
|
467 |
+
self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
|
468 |
+
self.Br = torch.nn.Parameter(torch.Tensor(num_features))
|
469 |
+
self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
|
470 |
+
else:
|
471 |
+
self.register_parameter('Wrr', None)
|
472 |
+
self.register_parameter('Wri', None)
|
473 |
+
self.register_parameter('Wii', None)
|
474 |
+
self.register_parameter('Br', None)
|
475 |
+
self.register_parameter('Bi', None)
|
476 |
+
if self.track_running_stats:
|
477 |
+
self.register_buffer('RMr', torch.zeros(num_features))
|
478 |
+
self.register_buffer('RMi', torch.zeros(num_features))
|
479 |
+
self.register_buffer('RVrr', torch.ones (num_features))
|
480 |
+
self.register_buffer('RVri', torch.zeros(num_features))
|
481 |
+
self.register_buffer('RVii', torch.ones (num_features))
|
482 |
+
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
|
483 |
+
else:
|
484 |
+
self.register_parameter('RMr', None)
|
485 |
+
self.register_parameter('RMi', None)
|
486 |
+
self.register_parameter('RVrr', None)
|
487 |
+
self.register_parameter('RVri', None)
|
488 |
+
self.register_parameter('RVii', None)
|
489 |
+
self.register_parameter('num_batches_tracked', None)
|
490 |
+
self.reset_parameters()
|
491 |
+
|
492 |
+
def reset_running_stats(self):
|
493 |
+
if self.track_running_stats:
|
494 |
+
self.RMr.zero_()
|
495 |
+
self.RMi.zero_()
|
496 |
+
self.RVrr.fill_(1)
|
497 |
+
self.RVri.zero_()
|
498 |
+
self.RVii.fill_(1)
|
499 |
+
self.num_batches_tracked.zero_()
|
500 |
+
|
501 |
+
def reset_parameters(self):
|
502 |
+
self.reset_running_stats()
|
503 |
+
if self.affine:
|
504 |
+
self.Br.data.zero_()
|
505 |
+
self.Bi.data.zero_()
|
506 |
+
self.Wrr.data.fill_(1)
|
507 |
+
self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite
|
508 |
+
self.Wii.data.fill_(1)
|
509 |
+
|
510 |
+
def _check_input_dim(self, xr, xi):
|
511 |
+
assert(xr.shape == xi.shape)
|
512 |
+
assert(xr.size(1) == self.num_features)
|
513 |
+
|
514 |
+
def forward(self, x):
|
515 |
+
xr, xi = x.real, x.imag
|
516 |
+
self._check_input_dim(xr, xi)
|
517 |
+
|
518 |
+
exponential_average_factor = 0.0
|
519 |
+
|
520 |
+
if self.training and self.track_running_stats:
|
521 |
+
self.num_batches_tracked += 1
|
522 |
+
if self.momentum is None: # use cumulative moving average
|
523 |
+
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
|
524 |
+
else: # use exponential moving average
|
525 |
+
exponential_average_factor = self.momentum
|
526 |
+
|
527 |
+
#
|
528 |
+
# NOTE: The precise meaning of the "training flag" is:
|
529 |
+
# True: Normalize using batch statistics, update running statistics
|
530 |
+
# if they are being collected.
|
531 |
+
# False: Normalize using running statistics, ignore batch statistics.
|
532 |
+
#
|
533 |
+
training = self.training or not self.track_running_stats
|
534 |
+
redux = [i for i in reversed(range(xr.dim())) if i!=1]
|
535 |
+
vdim = [1] * xr.dim()
|
536 |
+
vdim[1] = xr.size(1)
|
537 |
+
|
538 |
+
#
|
539 |
+
# Mean M Computation and Centering
|
540 |
+
#
|
541 |
+
# Includes running mean update if training and running.
|
542 |
+
#
|
543 |
+
if training:
|
544 |
+
Mr, Mi = xr, xi
|
545 |
+
for d in redux:
|
546 |
+
Mr = Mr.mean(d, keepdim=True)
|
547 |
+
Mi = Mi.mean(d, keepdim=True)
|
548 |
+
if self.track_running_stats:
|
549 |
+
self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
|
550 |
+
self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
|
551 |
+
else:
|
552 |
+
Mr = self.RMr.view(vdim)
|
553 |
+
Mi = self.RMi.view(vdim)
|
554 |
+
xr, xi = xr-Mr, xi-Mi
|
555 |
+
|
556 |
+
#
|
557 |
+
# Variance Matrix V Computation
|
558 |
+
#
|
559 |
+
# Includes epsilon numerical stabilizer/Tikhonov regularizer.
|
560 |
+
# Includes running variance update if training and running.
|
561 |
+
#
|
562 |
+
if training:
|
563 |
+
Vrr = xr * xr
|
564 |
+
Vri = xr * xi
|
565 |
+
Vii = xi * xi
|
566 |
+
for d in redux:
|
567 |
+
Vrr = Vrr.mean(d, keepdim=True)
|
568 |
+
Vri = Vri.mean(d, keepdim=True)
|
569 |
+
Vii = Vii.mean(d, keepdim=True)
|
570 |
+
if self.track_running_stats:
|
571 |
+
self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
|
572 |
+
self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
|
573 |
+
self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
|
574 |
+
else:
|
575 |
+
Vrr = self.RVrr.view(vdim)
|
576 |
+
Vri = self.RVri.view(vdim)
|
577 |
+
Vii = self.RVii.view(vdim)
|
578 |
+
Vrr = Vrr + self.eps
|
579 |
+
Vri = Vri
|
580 |
+
Vii = Vii + self.eps
|
581 |
+
|
582 |
+
#
|
583 |
+
# Matrix Inverse Square Root U = V^-0.5
|
584 |
+
#
|
585 |
+
# sqrt of a 2x2 matrix,
|
586 |
+
# - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
|
587 |
+
tau = Vrr + Vii
|
588 |
+
delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
|
589 |
+
s = delta.sqrt()
|
590 |
+
t = (tau + 2*s).sqrt()
|
591 |
+
|
592 |
+
# matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
|
593 |
+
rst = (s * t).reciprocal()
|
594 |
+
Urr = (s + Vii) * rst
|
595 |
+
Uii = (s + Vrr) * rst
|
596 |
+
Uri = ( - Vri) * rst
|
597 |
+
|
598 |
+
#
|
599 |
+
# Optionally left-multiply U by affine weights W to produce combined
|
600 |
+
# weights Z, left-multiply the inputs by Z, then optionally bias them.
|
601 |
+
#
|
602 |
+
# y = Zx + B
|
603 |
+
# y = WUx + B
|
604 |
+
# y = [Wrr Wri][Urr Uri] [xr] + [Br]
|
605 |
+
# [Wir Wii][Uir Uii] [xi] [Bi]
|
606 |
+
#
|
607 |
+
if self.affine:
|
608 |
+
Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
|
609 |
+
Zrr = (Wrr * Urr) + (Wri * Uri)
|
610 |
+
Zri = (Wrr * Uri) + (Wri * Uii)
|
611 |
+
Zir = (Wri * Urr) + (Wii * Uri)
|
612 |
+
Zii = (Wri * Uri) + (Wii * Uii)
|
613 |
+
else:
|
614 |
+
Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
|
615 |
+
|
616 |
+
yr = (Zrr * xr) + (Zri * xi)
|
617 |
+
yi = (Zir * xr) + (Zii * xi)
|
618 |
+
|
619 |
+
if self.affine:
|
620 |
+
yr = yr + self.Br.view(vdim)
|
621 |
+
yi = yi + self.Bi.view(vdim)
|
622 |
+
|
623 |
+
return torch.view_as_complex(torch.stack([yr, yi], dim=-1))
|
624 |
+
|
625 |
+
def extra_repr(self):
|
626 |
+
return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
|
627 |
+
'track_running_stats={track_running_stats}'.format(**self.__dict__)
|
sgmse/backbones/ncsnpp.py
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
|
18 |
+
from .ncsnpp_utils import layers, layerspp, normalization
|
19 |
+
import torch.nn as nn
|
20 |
+
import functools
|
21 |
+
import torch
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from .shared import BackboneRegistry
|
25 |
+
|
26 |
+
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
|
27 |
+
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
|
28 |
+
Combine = layerspp.Combine
|
29 |
+
conv3x3 = layerspp.conv3x3
|
30 |
+
conv1x1 = layerspp.conv1x1
|
31 |
+
get_act = layers.get_act
|
32 |
+
get_normalization = normalization.get_normalization
|
33 |
+
default_initializer = layers.default_init
|
34 |
+
|
35 |
+
|
36 |
+
@BackboneRegistry.register("ncsnpp")
|
37 |
+
class NCSNpp(nn.Module):
|
38 |
+
"""NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def add_argparse_args(parser):
|
42 |
+
parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
|
43 |
+
parser.add_argument("--num_res_blocks", type=int, default=2)
|
44 |
+
parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[16])
|
45 |
+
parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
|
46 |
+
parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
|
47 |
+
parser.set_defaults(centered=True)
|
48 |
+
return parser
|
49 |
+
|
50 |
+
def __init__(self,
|
51 |
+
scale_by_sigma = True,
|
52 |
+
nonlinearity = 'swish',
|
53 |
+
nf = 128,
|
54 |
+
ch_mult = (1, 1, 2, 2, 2, 2, 2),
|
55 |
+
num_res_blocks = 2,
|
56 |
+
attn_resolutions = (16,),
|
57 |
+
resamp_with_conv = True,
|
58 |
+
conditional = True,
|
59 |
+
fir = True,
|
60 |
+
fir_kernel = [1, 3, 3, 1],
|
61 |
+
skip_rescale = True,
|
62 |
+
resblock_type = 'biggan',
|
63 |
+
progressive = 'output_skip',
|
64 |
+
progressive_input = 'input_skip',
|
65 |
+
progressive_combine = 'sum',
|
66 |
+
init_scale = 0.,
|
67 |
+
fourier_scale = 16,
|
68 |
+
image_size = 256,
|
69 |
+
embedding_type = 'fourier',
|
70 |
+
dropout = .0,
|
71 |
+
centered = True,
|
72 |
+
**unused_kwargs
|
73 |
+
):
|
74 |
+
super().__init__()
|
75 |
+
self.act = act = get_act(nonlinearity)
|
76 |
+
|
77 |
+
self.nf = nf = nf
|
78 |
+
ch_mult = ch_mult
|
79 |
+
self.num_res_blocks = num_res_blocks = num_res_blocks
|
80 |
+
self.attn_resolutions = attn_resolutions = attn_resolutions
|
81 |
+
dropout = dropout
|
82 |
+
resamp_with_conv = resamp_with_conv
|
83 |
+
self.num_resolutions = num_resolutions = len(ch_mult)
|
84 |
+
self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
|
85 |
+
|
86 |
+
self.conditional = conditional = conditional # noise-conditional
|
87 |
+
self.centered = centered
|
88 |
+
self.scale_by_sigma = scale_by_sigma
|
89 |
+
|
90 |
+
fir = fir
|
91 |
+
fir_kernel = fir_kernel
|
92 |
+
self.skip_rescale = skip_rescale = skip_rescale
|
93 |
+
self.resblock_type = resblock_type = resblock_type.lower()
|
94 |
+
self.progressive = progressive = progressive.lower()
|
95 |
+
self.progressive_input = progressive_input = progressive_input.lower()
|
96 |
+
self.embedding_type = embedding_type = embedding_type.lower()
|
97 |
+
init_scale = init_scale
|
98 |
+
assert progressive in ['none', 'output_skip', 'residual']
|
99 |
+
assert progressive_input in ['none', 'input_skip', 'residual']
|
100 |
+
assert embedding_type in ['fourier', 'positional']
|
101 |
+
combine_method = progressive_combine.lower()
|
102 |
+
combiner = functools.partial(Combine, method=combine_method)
|
103 |
+
|
104 |
+
num_channels = 4 # x.real, x.imag, y.real, y.imag
|
105 |
+
self.output_layer = nn.Conv2d(num_channels, 2, 1)
|
106 |
+
|
107 |
+
modules = []
|
108 |
+
# timestep/noise_level embedding
|
109 |
+
if embedding_type == 'fourier':
|
110 |
+
# Gaussian Fourier features embeddings.
|
111 |
+
modules.append(layerspp.GaussianFourierProjection(
|
112 |
+
embedding_size=nf, scale=fourier_scale
|
113 |
+
))
|
114 |
+
embed_dim = 2 * nf
|
115 |
+
elif embedding_type == 'positional':
|
116 |
+
embed_dim = nf
|
117 |
+
else:
|
118 |
+
raise ValueError(f'embedding type {embedding_type} unknown.')
|
119 |
+
|
120 |
+
if conditional:
|
121 |
+
modules.append(nn.Linear(embed_dim, nf * 4))
|
122 |
+
modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
|
123 |
+
nn.init.zeros_(modules[-1].bias)
|
124 |
+
modules.append(nn.Linear(nf * 4, nf * 4))
|
125 |
+
modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
|
126 |
+
nn.init.zeros_(modules[-1].bias)
|
127 |
+
|
128 |
+
AttnBlock = functools.partial(layerspp.AttnBlockpp,
|
129 |
+
init_scale=init_scale, skip_rescale=skip_rescale)
|
130 |
+
|
131 |
+
Upsample = functools.partial(layerspp.Upsample,
|
132 |
+
with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
133 |
+
|
134 |
+
if progressive == 'output_skip':
|
135 |
+
self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
136 |
+
elif progressive == 'residual':
|
137 |
+
pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
|
138 |
+
fir_kernel=fir_kernel, with_conv=True)
|
139 |
+
|
140 |
+
Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
141 |
+
|
142 |
+
if progressive_input == 'input_skip':
|
143 |
+
self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
144 |
+
elif progressive_input == 'residual':
|
145 |
+
pyramid_downsample = functools.partial(layerspp.Downsample,
|
146 |
+
fir=fir, fir_kernel=fir_kernel, with_conv=True)
|
147 |
+
|
148 |
+
if resblock_type == 'ddpm':
|
149 |
+
ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
|
150 |
+
dropout=dropout, init_scale=init_scale,
|
151 |
+
skip_rescale=skip_rescale, temb_dim=nf * 4)
|
152 |
+
|
153 |
+
elif resblock_type == 'biggan':
|
154 |
+
ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
|
155 |
+
dropout=dropout, fir=fir, fir_kernel=fir_kernel,
|
156 |
+
init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
|
157 |
+
|
158 |
+
else:
|
159 |
+
raise ValueError(f'resblock type {resblock_type} unrecognized.')
|
160 |
+
|
161 |
+
# Downsampling block
|
162 |
+
|
163 |
+
channels = num_channels
|
164 |
+
if progressive_input != 'none':
|
165 |
+
input_pyramid_ch = channels
|
166 |
+
|
167 |
+
modules.append(conv3x3(channels, nf))
|
168 |
+
hs_c = [nf]
|
169 |
+
|
170 |
+
in_ch = nf
|
171 |
+
for i_level in range(num_resolutions):
|
172 |
+
# Residual blocks for this resolution
|
173 |
+
for i_block in range(num_res_blocks):
|
174 |
+
out_ch = nf * ch_mult[i_level]
|
175 |
+
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
176 |
+
in_ch = out_ch
|
177 |
+
|
178 |
+
if all_resolutions[i_level] in attn_resolutions:
|
179 |
+
modules.append(AttnBlock(channels=in_ch))
|
180 |
+
hs_c.append(in_ch)
|
181 |
+
|
182 |
+
if i_level != num_resolutions - 1:
|
183 |
+
if resblock_type == 'ddpm':
|
184 |
+
modules.append(Downsample(in_ch=in_ch))
|
185 |
+
else:
|
186 |
+
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
187 |
+
|
188 |
+
if progressive_input == 'input_skip':
|
189 |
+
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
190 |
+
if combine_method == 'cat':
|
191 |
+
in_ch *= 2
|
192 |
+
|
193 |
+
elif progressive_input == 'residual':
|
194 |
+
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
|
195 |
+
input_pyramid_ch = in_ch
|
196 |
+
|
197 |
+
hs_c.append(in_ch)
|
198 |
+
|
199 |
+
in_ch = hs_c[-1]
|
200 |
+
modules.append(ResnetBlock(in_ch=in_ch))
|
201 |
+
modules.append(AttnBlock(channels=in_ch))
|
202 |
+
modules.append(ResnetBlock(in_ch=in_ch))
|
203 |
+
|
204 |
+
pyramid_ch = 0
|
205 |
+
# Upsampling block
|
206 |
+
for i_level in reversed(range(num_resolutions)):
|
207 |
+
for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
|
208 |
+
out_ch = nf * ch_mult[i_level]
|
209 |
+
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
210 |
+
in_ch = out_ch
|
211 |
+
|
212 |
+
if all_resolutions[i_level] in attn_resolutions:
|
213 |
+
modules.append(AttnBlock(channels=in_ch))
|
214 |
+
|
215 |
+
if progressive != 'none':
|
216 |
+
if i_level == num_resolutions - 1:
|
217 |
+
if progressive == 'output_skip':
|
218 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
219 |
+
num_channels=in_ch, eps=1e-6))
|
220 |
+
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
221 |
+
pyramid_ch = channels
|
222 |
+
elif progressive == 'residual':
|
223 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
224 |
+
modules.append(conv3x3(in_ch, in_ch, bias=True))
|
225 |
+
pyramid_ch = in_ch
|
226 |
+
else:
|
227 |
+
raise ValueError(f'{progressive} is not a valid name.')
|
228 |
+
else:
|
229 |
+
if progressive == 'output_skip':
|
230 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
231 |
+
num_channels=in_ch, eps=1e-6))
|
232 |
+
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
|
233 |
+
pyramid_ch = channels
|
234 |
+
elif progressive == 'residual':
|
235 |
+
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
|
236 |
+
pyramid_ch = in_ch
|
237 |
+
else:
|
238 |
+
raise ValueError(f'{progressive} is not a valid name')
|
239 |
+
|
240 |
+
if i_level != 0:
|
241 |
+
if resblock_type == 'ddpm':
|
242 |
+
modules.append(Upsample(in_ch=in_ch))
|
243 |
+
else:
|
244 |
+
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
245 |
+
|
246 |
+
assert not hs_c
|
247 |
+
|
248 |
+
if progressive != 'output_skip':
|
249 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
250 |
+
num_channels=in_ch, eps=1e-6))
|
251 |
+
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
252 |
+
|
253 |
+
self.all_modules = nn.ModuleList(modules)
|
254 |
+
|
255 |
+
|
256 |
+
def forward(self, x, time_cond):
|
257 |
+
# timestep/noise_level embedding; only for continuous training
|
258 |
+
modules = self.all_modules
|
259 |
+
m_idx = 0
|
260 |
+
|
261 |
+
# Convert real and imaginary parts of (x,y) into four channel dimensions
|
262 |
+
x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
|
263 |
+
x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
|
264 |
+
|
265 |
+
if self.embedding_type == 'fourier':
|
266 |
+
# Gaussian Fourier features embeddings.
|
267 |
+
used_sigmas = time_cond
|
268 |
+
temb = modules[m_idx](torch.log(used_sigmas))
|
269 |
+
m_idx += 1
|
270 |
+
|
271 |
+
elif self.embedding_type == 'positional':
|
272 |
+
# Sinusoidal positional embeddings.
|
273 |
+
timesteps = time_cond
|
274 |
+
used_sigmas = self.sigmas[time_cond.long()]
|
275 |
+
temb = layers.get_timestep_embedding(timesteps, self.nf)
|
276 |
+
|
277 |
+
else:
|
278 |
+
raise ValueError(f'embedding type {self.embedding_type} unknown.')
|
279 |
+
|
280 |
+
if self.conditional:
|
281 |
+
temb = modules[m_idx](temb)
|
282 |
+
m_idx += 1
|
283 |
+
temb = modules[m_idx](self.act(temb))
|
284 |
+
m_idx += 1
|
285 |
+
else:
|
286 |
+
temb = None
|
287 |
+
|
288 |
+
if not self.centered:
|
289 |
+
# If input data is in [0, 1]
|
290 |
+
x = 2 * x - 1.
|
291 |
+
|
292 |
+
# Downsampling block
|
293 |
+
input_pyramid = None
|
294 |
+
if self.progressive_input != 'none':
|
295 |
+
input_pyramid = x
|
296 |
+
|
297 |
+
# Input layer: Conv2d: 4ch -> 128ch
|
298 |
+
hs = [modules[m_idx](x)]
|
299 |
+
m_idx += 1
|
300 |
+
|
301 |
+
# Down path in U-Net
|
302 |
+
for i_level in range(self.num_resolutions):
|
303 |
+
# Residual blocks for this resolution
|
304 |
+
for i_block in range(self.num_res_blocks):
|
305 |
+
h = modules[m_idx](hs[-1], temb)
|
306 |
+
m_idx += 1
|
307 |
+
# Attention layer (optional)
|
308 |
+
if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
|
309 |
+
h = modules[m_idx](h)
|
310 |
+
m_idx += 1
|
311 |
+
hs.append(h)
|
312 |
+
|
313 |
+
# Downsampling
|
314 |
+
if i_level != self.num_resolutions - 1:
|
315 |
+
if self.resblock_type == 'ddpm':
|
316 |
+
h = modules[m_idx](hs[-1])
|
317 |
+
m_idx += 1
|
318 |
+
else:
|
319 |
+
h = modules[m_idx](hs[-1], temb)
|
320 |
+
m_idx += 1
|
321 |
+
|
322 |
+
if self.progressive_input == 'input_skip': # Combine h with x
|
323 |
+
input_pyramid = self.pyramid_downsample(input_pyramid)
|
324 |
+
h = modules[m_idx](input_pyramid, h)
|
325 |
+
m_idx += 1
|
326 |
+
|
327 |
+
elif self.progressive_input == 'residual':
|
328 |
+
input_pyramid = modules[m_idx](input_pyramid)
|
329 |
+
m_idx += 1
|
330 |
+
if self.skip_rescale:
|
331 |
+
input_pyramid = (input_pyramid + h) / np.sqrt(2.)
|
332 |
+
else:
|
333 |
+
input_pyramid = input_pyramid + h
|
334 |
+
h = input_pyramid
|
335 |
+
hs.append(h)
|
336 |
+
|
337 |
+
h = hs[-1] # actualy equal to: h = h
|
338 |
+
h = modules[m_idx](h, temb) # ResNet block
|
339 |
+
m_idx += 1
|
340 |
+
h = modules[m_idx](h) # Attention block
|
341 |
+
m_idx += 1
|
342 |
+
h = modules[m_idx](h, temb) # ResNet block
|
343 |
+
m_idx += 1
|
344 |
+
|
345 |
+
pyramid = None
|
346 |
+
|
347 |
+
# Upsampling block
|
348 |
+
for i_level in reversed(range(self.num_resolutions)):
|
349 |
+
for i_block in range(self.num_res_blocks + 1):
|
350 |
+
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
|
351 |
+
m_idx += 1
|
352 |
+
|
353 |
+
# edit: from -1 to -2
|
354 |
+
if h.shape[-2] in self.attn_resolutions:
|
355 |
+
h = modules[m_idx](h)
|
356 |
+
m_idx += 1
|
357 |
+
|
358 |
+
if self.progressive != 'none':
|
359 |
+
if i_level == self.num_resolutions - 1:
|
360 |
+
if self.progressive == 'output_skip':
|
361 |
+
pyramid = self.act(modules[m_idx](h)) # GroupNorm
|
362 |
+
m_idx += 1
|
363 |
+
pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
|
364 |
+
m_idx += 1
|
365 |
+
elif self.progressive == 'residual':
|
366 |
+
pyramid = self.act(modules[m_idx](h))
|
367 |
+
m_idx += 1
|
368 |
+
pyramid = modules[m_idx](pyramid)
|
369 |
+
m_idx += 1
|
370 |
+
else:
|
371 |
+
raise ValueError(f'{self.progressive} is not a valid name.')
|
372 |
+
else:
|
373 |
+
if self.progressive == 'output_skip':
|
374 |
+
pyramid = self.pyramid_upsample(pyramid) # Upsample
|
375 |
+
pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
|
376 |
+
m_idx += 1
|
377 |
+
pyramid_h = modules[m_idx](pyramid_h)
|
378 |
+
m_idx += 1
|
379 |
+
pyramid = pyramid + pyramid_h
|
380 |
+
elif self.progressive == 'residual':
|
381 |
+
pyramid = modules[m_idx](pyramid)
|
382 |
+
m_idx += 1
|
383 |
+
if self.skip_rescale:
|
384 |
+
pyramid = (pyramid + h) / np.sqrt(2.)
|
385 |
+
else:
|
386 |
+
pyramid = pyramid + h
|
387 |
+
h = pyramid
|
388 |
+
else:
|
389 |
+
raise ValueError(f'{self.progressive} is not a valid name')
|
390 |
+
|
391 |
+
# Upsampling Layer
|
392 |
+
if i_level != 0:
|
393 |
+
if self.resblock_type == 'ddpm':
|
394 |
+
h = modules[m_idx](h)
|
395 |
+
m_idx += 1
|
396 |
+
else:
|
397 |
+
h = modules[m_idx](h, temb) # Upspampling
|
398 |
+
m_idx += 1
|
399 |
+
|
400 |
+
assert not hs
|
401 |
+
|
402 |
+
if self.progressive == 'output_skip':
|
403 |
+
h = pyramid
|
404 |
+
else:
|
405 |
+
h = self.act(modules[m_idx](h))
|
406 |
+
m_idx += 1
|
407 |
+
h = modules[m_idx](h)
|
408 |
+
m_idx += 1
|
409 |
+
|
410 |
+
assert m_idx == len(modules), "Implementation error"
|
411 |
+
if self.scale_by_sigma:
|
412 |
+
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
|
413 |
+
h = h / used_sigmas
|
414 |
+
|
415 |
+
# Convert back to complex number
|
416 |
+
h = self.output_layer(h)
|
417 |
+
h = torch.permute(h, (0, 2, 3, 1)).contiguous()
|
418 |
+
h = torch.view_as_complex(h)[:,None, :, :]
|
419 |
+
return h
|
sgmse/backbones/ncsnpp_48k.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
|
18 |
+
from .ncsnpp_utils import layers, layerspp, normalization
|
19 |
+
import torch.nn as nn
|
20 |
+
import functools
|
21 |
+
import torch
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from .shared import BackboneRegistry
|
25 |
+
|
26 |
+
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
|
27 |
+
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
|
28 |
+
Combine = layerspp.Combine
|
29 |
+
conv3x3 = layerspp.conv3x3
|
30 |
+
conv1x1 = layerspp.conv1x1
|
31 |
+
get_act = layers.get_act
|
32 |
+
get_normalization = normalization.get_normalization
|
33 |
+
default_initializer = layers.default_init
|
34 |
+
|
35 |
+
|
36 |
+
@BackboneRegistry.register("ncsnpp_48k")
|
37 |
+
class NCSNpp_48k(nn.Module):
|
38 |
+
"""NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def add_argparse_args(parser):
|
42 |
+
parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
|
43 |
+
parser.add_argument("--num_res_blocks", type=int, default=2)
|
44 |
+
parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
|
45 |
+
parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
|
46 |
+
parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
|
47 |
+
parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
|
48 |
+
parser.add_argument("--progressive", type=str, default='none', help="Progressive downsampling method")
|
49 |
+
parser.add_argument("--progressive_input", type=str, default='none', help="Progressive upsampling method")
|
50 |
+
parser.set_defaults(centered=True)
|
51 |
+
return parser
|
52 |
+
|
53 |
+
def __init__(self,
|
54 |
+
scale_by_sigma = True,
|
55 |
+
nonlinearity = 'swish',
|
56 |
+
nf = 128,
|
57 |
+
ch_mult = (1, 1, 2, 2, 2, 2, 2),
|
58 |
+
num_res_blocks = 2,
|
59 |
+
attn_resolutions = (),
|
60 |
+
resamp_with_conv = True,
|
61 |
+
conditional = True,
|
62 |
+
fir = True,
|
63 |
+
fir_kernel = [1, 3, 3, 1],
|
64 |
+
skip_rescale = True,
|
65 |
+
resblock_type = 'biggan',
|
66 |
+
progressive = 'none',
|
67 |
+
progressive_input = 'none',
|
68 |
+
progressive_combine = 'sum',
|
69 |
+
init_scale = 0.,
|
70 |
+
fourier_scale = 16,
|
71 |
+
image_size = 256,
|
72 |
+
embedding_type = 'fourier',
|
73 |
+
dropout = .0,
|
74 |
+
centered = True,
|
75 |
+
**unused_kwargs
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
self.act = act = get_act(nonlinearity)
|
79 |
+
|
80 |
+
self.nf = nf = nf
|
81 |
+
ch_mult = ch_mult
|
82 |
+
self.num_res_blocks = num_res_blocks = num_res_blocks
|
83 |
+
self.attn_resolutions = attn_resolutions
|
84 |
+
dropout = dropout
|
85 |
+
resamp_with_conv = resamp_with_conv
|
86 |
+
self.num_resolutions = num_resolutions = len(ch_mult)
|
87 |
+
self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
|
88 |
+
|
89 |
+
self.conditional = conditional = conditional # noise-conditional
|
90 |
+
self.centered = centered
|
91 |
+
self.scale_by_sigma = scale_by_sigma
|
92 |
+
|
93 |
+
fir = fir
|
94 |
+
fir_kernel = fir_kernel
|
95 |
+
self.skip_rescale = skip_rescale = skip_rescale
|
96 |
+
self.resblock_type = resblock_type = resblock_type.lower()
|
97 |
+
self.progressive = progressive = progressive.lower()
|
98 |
+
self.progressive_input = progressive_input = progressive_input.lower()
|
99 |
+
self.embedding_type = embedding_type = embedding_type.lower()
|
100 |
+
init_scale = init_scale
|
101 |
+
assert progressive in ['none', 'output_skip', 'residual']
|
102 |
+
assert progressive_input in ['none', 'input_skip', 'residual']
|
103 |
+
assert embedding_type in ['fourier', 'positional']
|
104 |
+
combine_method = progressive_combine.lower()
|
105 |
+
combiner = functools.partial(Combine, method=combine_method)
|
106 |
+
|
107 |
+
num_channels = 4 # x.real, x.imag, y.real, y.imag
|
108 |
+
self.output_layer = nn.Conv2d(num_channels, 2, 1)
|
109 |
+
|
110 |
+
modules = []
|
111 |
+
# timestep/noise_level embedding
|
112 |
+
if embedding_type == 'fourier':
|
113 |
+
# Gaussian Fourier features embeddings.
|
114 |
+
modules.append(layerspp.GaussianFourierProjection(
|
115 |
+
embedding_size=nf, scale=fourier_scale
|
116 |
+
))
|
117 |
+
embed_dim = 2 * nf
|
118 |
+
elif embedding_type == 'positional':
|
119 |
+
embed_dim = nf
|
120 |
+
else:
|
121 |
+
raise ValueError(f'embedding type {embedding_type} unknown.')
|
122 |
+
|
123 |
+
if conditional:
|
124 |
+
modules.append(nn.Linear(embed_dim, nf * 4))
|
125 |
+
modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
|
126 |
+
nn.init.zeros_(modules[-1].bias)
|
127 |
+
modules.append(nn.Linear(nf * 4, nf * 4))
|
128 |
+
modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
|
129 |
+
nn.init.zeros_(modules[-1].bias)
|
130 |
+
|
131 |
+
AttnBlock = functools.partial(layerspp.AttnBlockpp,
|
132 |
+
init_scale=init_scale, skip_rescale=skip_rescale)
|
133 |
+
|
134 |
+
Upsample = functools.partial(layerspp.Upsample,
|
135 |
+
with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
136 |
+
|
137 |
+
if progressive == 'output_skip':
|
138 |
+
self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
139 |
+
elif progressive == 'residual':
|
140 |
+
pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
|
141 |
+
fir_kernel=fir_kernel, with_conv=True)
|
142 |
+
|
143 |
+
Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
|
144 |
+
|
145 |
+
if progressive_input == 'input_skip':
|
146 |
+
self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
|
147 |
+
elif progressive_input == 'residual':
|
148 |
+
pyramid_downsample = functools.partial(layerspp.Downsample,
|
149 |
+
fir=fir, fir_kernel=fir_kernel, with_conv=True)
|
150 |
+
|
151 |
+
if resblock_type == 'ddpm':
|
152 |
+
ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
|
153 |
+
dropout=dropout, init_scale=init_scale,
|
154 |
+
skip_rescale=skip_rescale, temb_dim=nf * 4)
|
155 |
+
|
156 |
+
elif resblock_type == 'biggan':
|
157 |
+
ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
|
158 |
+
dropout=dropout, fir=fir, fir_kernel=fir_kernel,
|
159 |
+
init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
|
160 |
+
|
161 |
+
else:
|
162 |
+
raise ValueError(f'resblock type {resblock_type} unrecognized.')
|
163 |
+
|
164 |
+
# Downsampling block
|
165 |
+
|
166 |
+
channels = num_channels
|
167 |
+
if progressive_input != 'none':
|
168 |
+
input_pyramid_ch = channels
|
169 |
+
|
170 |
+
modules.append(conv3x3(channels, nf))
|
171 |
+
hs_c = [nf]
|
172 |
+
|
173 |
+
in_ch = nf
|
174 |
+
for i_level in range(num_resolutions):
|
175 |
+
# Residual blocks for this resolution
|
176 |
+
for i_block in range(num_res_blocks):
|
177 |
+
out_ch = nf * ch_mult[i_level]
|
178 |
+
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
179 |
+
in_ch = out_ch
|
180 |
+
|
181 |
+
if all_resolutions[i_level] in attn_resolutions:
|
182 |
+
modules.append(AttnBlock(channels=in_ch))
|
183 |
+
hs_c.append(in_ch)
|
184 |
+
|
185 |
+
if i_level != num_resolutions - 1:
|
186 |
+
if resblock_type == 'ddpm':
|
187 |
+
modules.append(Downsample(in_ch=in_ch))
|
188 |
+
else:
|
189 |
+
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
190 |
+
|
191 |
+
if progressive_input == 'input_skip':
|
192 |
+
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
193 |
+
if combine_method == 'cat':
|
194 |
+
in_ch *= 2
|
195 |
+
|
196 |
+
elif progressive_input == 'residual':
|
197 |
+
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
|
198 |
+
input_pyramid_ch = in_ch
|
199 |
+
|
200 |
+
hs_c.append(in_ch)
|
201 |
+
|
202 |
+
in_ch = hs_c[-1]
|
203 |
+
modules.append(ResnetBlock(in_ch=in_ch))
|
204 |
+
modules.append(AttnBlock(channels=in_ch))
|
205 |
+
modules.append(ResnetBlock(in_ch=in_ch))
|
206 |
+
|
207 |
+
pyramid_ch = 0
|
208 |
+
# Upsampling block
|
209 |
+
for i_level in reversed(range(num_resolutions)):
|
210 |
+
for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
|
211 |
+
out_ch = nf * ch_mult[i_level]
|
212 |
+
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
213 |
+
in_ch = out_ch
|
214 |
+
|
215 |
+
if all_resolutions[i_level] in attn_resolutions:
|
216 |
+
modules.append(AttnBlock(channels=in_ch))
|
217 |
+
|
218 |
+
if progressive != 'none':
|
219 |
+
if i_level == num_resolutions - 1:
|
220 |
+
if progressive == 'output_skip':
|
221 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
222 |
+
num_channels=in_ch, eps=1e-6))
|
223 |
+
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
224 |
+
pyramid_ch = channels
|
225 |
+
elif progressive == 'residual':
|
226 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
|
227 |
+
modules.append(conv3x3(in_ch, in_ch, bias=True))
|
228 |
+
pyramid_ch = in_ch
|
229 |
+
else:
|
230 |
+
raise ValueError(f'{progressive} is not a valid name.')
|
231 |
+
else:
|
232 |
+
if progressive == 'output_skip':
|
233 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
234 |
+
num_channels=in_ch, eps=1e-6))
|
235 |
+
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
|
236 |
+
pyramid_ch = channels
|
237 |
+
elif progressive == 'residual':
|
238 |
+
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
|
239 |
+
pyramid_ch = in_ch
|
240 |
+
else:
|
241 |
+
raise ValueError(f'{progressive} is not a valid name')
|
242 |
+
|
243 |
+
if i_level != 0:
|
244 |
+
if resblock_type == 'ddpm':
|
245 |
+
modules.append(Upsample(in_ch=in_ch))
|
246 |
+
else:
|
247 |
+
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
248 |
+
|
249 |
+
assert not hs_c
|
250 |
+
|
251 |
+
if progressive != 'output_skip':
|
252 |
+
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
253 |
+
num_channels=in_ch, eps=1e-6))
|
254 |
+
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
|
255 |
+
|
256 |
+
self.all_modules = nn.ModuleList(modules)
|
257 |
+
|
258 |
+
|
259 |
+
def forward(self, x, time_cond):
|
260 |
+
# timestep/noise_level embedding; only for continuous training
|
261 |
+
modules = self.all_modules
|
262 |
+
m_idx = 0
|
263 |
+
|
264 |
+
# Convert real and imaginary parts of (x,y) into four channel dimensions
|
265 |
+
x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
|
266 |
+
x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
|
267 |
+
|
268 |
+
if self.embedding_type == 'fourier':
|
269 |
+
# Gaussian Fourier features embeddings.
|
270 |
+
used_sigmas = time_cond
|
271 |
+
temb = modules[m_idx](torch.log(used_sigmas))
|
272 |
+
m_idx += 1
|
273 |
+
|
274 |
+
elif self.embedding_type == 'positional':
|
275 |
+
# Sinusoidal positional embeddings.
|
276 |
+
timesteps = time_cond
|
277 |
+
used_sigmas = self.sigmas[time_cond.long()]
|
278 |
+
temb = layers.get_timestep_embedding(timesteps, self.nf)
|
279 |
+
|
280 |
+
else:
|
281 |
+
raise ValueError(f'embedding type {self.embedding_type} unknown.')
|
282 |
+
|
283 |
+
if self.conditional:
|
284 |
+
temb = modules[m_idx](temb)
|
285 |
+
m_idx += 1
|
286 |
+
temb = modules[m_idx](self.act(temb))
|
287 |
+
m_idx += 1
|
288 |
+
else:
|
289 |
+
temb = None
|
290 |
+
|
291 |
+
if not self.centered:
|
292 |
+
# If input data is in [0, 1]
|
293 |
+
x = 2 * x - 1.
|
294 |
+
|
295 |
+
# Downsampling block
|
296 |
+
input_pyramid = None
|
297 |
+
if self.progressive_input != 'none':
|
298 |
+
input_pyramid = x
|
299 |
+
|
300 |
+
# Input layer: Conv2d: 4ch -> 128ch
|
301 |
+
hs = [modules[m_idx](x)]
|
302 |
+
m_idx += 1
|
303 |
+
|
304 |
+
# Down path in U-Net
|
305 |
+
for i_level in range(self.num_resolutions):
|
306 |
+
# Residual blocks for this resolution
|
307 |
+
for i_block in range(self.num_res_blocks):
|
308 |
+
h = modules[m_idx](hs[-1], temb)
|
309 |
+
m_idx += 1
|
310 |
+
# Attention layer (optional)
|
311 |
+
if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
|
312 |
+
h = modules[m_idx](h)
|
313 |
+
m_idx += 1
|
314 |
+
hs.append(h)
|
315 |
+
|
316 |
+
# Downsampling
|
317 |
+
if i_level != self.num_resolutions - 1:
|
318 |
+
if self.resblock_type == 'ddpm':
|
319 |
+
h = modules[m_idx](hs[-1])
|
320 |
+
m_idx += 1
|
321 |
+
else:
|
322 |
+
h = modules[m_idx](hs[-1], temb)
|
323 |
+
m_idx += 1
|
324 |
+
|
325 |
+
if self.progressive_input == 'input_skip': # Combine h with x
|
326 |
+
input_pyramid = self.pyramid_downsample(input_pyramid)
|
327 |
+
h = modules[m_idx](input_pyramid, h)
|
328 |
+
m_idx += 1
|
329 |
+
|
330 |
+
elif self.progressive_input == 'residual':
|
331 |
+
input_pyramid = modules[m_idx](input_pyramid)
|
332 |
+
m_idx += 1
|
333 |
+
if self.skip_rescale:
|
334 |
+
input_pyramid = (input_pyramid + h) / np.sqrt(2.)
|
335 |
+
else:
|
336 |
+
input_pyramid = input_pyramid + h
|
337 |
+
h = input_pyramid
|
338 |
+
hs.append(h)
|
339 |
+
|
340 |
+
h = hs[-1] # actualy equal to: h = h
|
341 |
+
h = modules[m_idx](h, temb) # ResNet block
|
342 |
+
m_idx += 1
|
343 |
+
h = modules[m_idx](h) # Attention block
|
344 |
+
m_idx += 1
|
345 |
+
h = modules[m_idx](h, temb) # ResNet block
|
346 |
+
m_idx += 1
|
347 |
+
|
348 |
+
pyramid = None
|
349 |
+
|
350 |
+
# Upsampling block
|
351 |
+
for i_level in reversed(range(self.num_resolutions)):
|
352 |
+
for i_block in range(self.num_res_blocks + 1):
|
353 |
+
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
|
354 |
+
m_idx += 1
|
355 |
+
|
356 |
+
# edit: from -1 to -2
|
357 |
+
if h.shape[-2] in self.attn_resolutions:
|
358 |
+
h = modules[m_idx](h)
|
359 |
+
m_idx += 1
|
360 |
+
|
361 |
+
if self.progressive != 'none':
|
362 |
+
if i_level == self.num_resolutions - 1:
|
363 |
+
if self.progressive == 'output_skip':
|
364 |
+
pyramid = self.act(modules[m_idx](h)) # GroupNorm
|
365 |
+
m_idx += 1
|
366 |
+
pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
|
367 |
+
m_idx += 1
|
368 |
+
elif self.progressive == 'residual':
|
369 |
+
pyramid = self.act(modules[m_idx](h))
|
370 |
+
m_idx += 1
|
371 |
+
pyramid = modules[m_idx](pyramid)
|
372 |
+
m_idx += 1
|
373 |
+
else:
|
374 |
+
raise ValueError(f'{self.progressive} is not a valid name.')
|
375 |
+
else:
|
376 |
+
if self.progressive == 'output_skip':
|
377 |
+
pyramid = self.pyramid_upsample(pyramid) # Upsample
|
378 |
+
pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
|
379 |
+
m_idx += 1
|
380 |
+
pyramid_h = modules[m_idx](pyramid_h)
|
381 |
+
m_idx += 1
|
382 |
+
pyramid = pyramid + pyramid_h
|
383 |
+
elif self.progressive == 'residual':
|
384 |
+
pyramid = modules[m_idx](pyramid)
|
385 |
+
m_idx += 1
|
386 |
+
if self.skip_rescale:
|
387 |
+
pyramid = (pyramid + h) / np.sqrt(2.)
|
388 |
+
else:
|
389 |
+
pyramid = pyramid + h
|
390 |
+
h = pyramid
|
391 |
+
else:
|
392 |
+
raise ValueError(f'{self.progressive} is not a valid name')
|
393 |
+
|
394 |
+
# Upsampling Layer
|
395 |
+
if i_level != 0:
|
396 |
+
if self.resblock_type == 'ddpm':
|
397 |
+
h = modules[m_idx](h)
|
398 |
+
m_idx += 1
|
399 |
+
else:
|
400 |
+
h = modules[m_idx](h, temb) # Upspampling
|
401 |
+
m_idx += 1
|
402 |
+
|
403 |
+
assert not hs
|
404 |
+
|
405 |
+
if self.progressive == 'output_skip':
|
406 |
+
h = pyramid
|
407 |
+
else:
|
408 |
+
h = self.act(modules[m_idx](h))
|
409 |
+
m_idx += 1
|
410 |
+
h = modules[m_idx](h)
|
411 |
+
m_idx += 1
|
412 |
+
|
413 |
+
assert m_idx == len(modules), "Implementation error"
|
414 |
+
|
415 |
+
# Convert back to complex number
|
416 |
+
h = self.output_layer(h)
|
417 |
+
|
418 |
+
if self.scale_by_sigma:
|
419 |
+
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
|
420 |
+
h = h / used_sigmas
|
421 |
+
|
422 |
+
h = torch.permute(h, (0, 2, 3, 1)).contiguous()
|
423 |
+
h = torch.view_as_complex(h)[:,None, :, :]
|
424 |
+
return h
|
sgmse/backbones/ncsnpp_utils/layers.py
ADDED
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
"""Common layers for defining score networks.
|
18 |
+
"""
|
19 |
+
import math
|
20 |
+
import string
|
21 |
+
from functools import partial
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
import numpy as np
|
26 |
+
from .normalization import ConditionalInstanceNorm2dPlus
|
27 |
+
|
28 |
+
|
29 |
+
def get_act(config):
|
30 |
+
"""Get activation functions from the config file."""
|
31 |
+
|
32 |
+
if config == 'elu':
|
33 |
+
return nn.ELU()
|
34 |
+
elif config == 'relu':
|
35 |
+
return nn.ReLU()
|
36 |
+
elif config == 'lrelu':
|
37 |
+
return nn.LeakyReLU(negative_slope=0.2)
|
38 |
+
elif config == 'swish':
|
39 |
+
return nn.SiLU()
|
40 |
+
else:
|
41 |
+
raise NotImplementedError('activation function does not exist!')
|
42 |
+
|
43 |
+
|
44 |
+
def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
|
45 |
+
"""1x1 convolution. Same as NCSNv1/v2."""
|
46 |
+
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
|
47 |
+
padding=padding)
|
48 |
+
init_scale = 1e-10 if init_scale == 0 else init_scale
|
49 |
+
conv.weight.data *= init_scale
|
50 |
+
conv.bias.data *= init_scale
|
51 |
+
return conv
|
52 |
+
|
53 |
+
|
54 |
+
def variance_scaling(scale, mode, distribution,
|
55 |
+
in_axis=1, out_axis=0,
|
56 |
+
dtype=torch.float32,
|
57 |
+
device='cpu'):
|
58 |
+
"""Ported from JAX. """
|
59 |
+
|
60 |
+
def _compute_fans(shape, in_axis=1, out_axis=0):
|
61 |
+
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
62 |
+
fan_in = shape[in_axis] * receptive_field_size
|
63 |
+
fan_out = shape[out_axis] * receptive_field_size
|
64 |
+
return fan_in, fan_out
|
65 |
+
|
66 |
+
def init(shape, dtype=dtype, device=device):
|
67 |
+
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
68 |
+
if mode == "fan_in":
|
69 |
+
denominator = fan_in
|
70 |
+
elif mode == "fan_out":
|
71 |
+
denominator = fan_out
|
72 |
+
elif mode == "fan_avg":
|
73 |
+
denominator = (fan_in + fan_out) / 2
|
74 |
+
else:
|
75 |
+
raise ValueError(
|
76 |
+
"invalid mode for variance scaling initializer: {}".format(mode))
|
77 |
+
variance = scale / denominator
|
78 |
+
if distribution == "normal":
|
79 |
+
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
|
80 |
+
elif distribution == "uniform":
|
81 |
+
return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
|
82 |
+
else:
|
83 |
+
raise ValueError("invalid distribution for variance scaling initializer")
|
84 |
+
|
85 |
+
return init
|
86 |
+
|
87 |
+
|
88 |
+
def default_init(scale=1.):
|
89 |
+
"""The same initialization used in DDPM."""
|
90 |
+
scale = 1e-10 if scale == 0 else scale
|
91 |
+
return variance_scaling(scale, 'fan_avg', 'uniform')
|
92 |
+
|
93 |
+
|
94 |
+
class Dense(nn.Module):
|
95 |
+
"""Linear layer with `default_init`."""
|
96 |
+
def __init__(self):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
|
100 |
+
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
|
101 |
+
"""1x1 convolution with DDPM initialization."""
|
102 |
+
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
|
103 |
+
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
104 |
+
nn.init.zeros_(conv.bias)
|
105 |
+
return conv
|
106 |
+
|
107 |
+
|
108 |
+
def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
|
109 |
+
"""3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
|
110 |
+
init_scale = 1e-10 if init_scale == 0 else init_scale
|
111 |
+
conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
|
112 |
+
dilation=dilation, padding=padding, kernel_size=3)
|
113 |
+
conv.weight.data *= init_scale
|
114 |
+
conv.bias.data *= init_scale
|
115 |
+
return conv
|
116 |
+
|
117 |
+
|
118 |
+
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
|
119 |
+
"""3x3 convolution with DDPM initialization."""
|
120 |
+
conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
|
121 |
+
dilation=dilation, bias=bias)
|
122 |
+
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
|
123 |
+
nn.init.zeros_(conv.bias)
|
124 |
+
return conv
|
125 |
+
|
126 |
+
###########################################################################
|
127 |
+
# Functions below are ported over from the NCSNv1/NCSNv2 codebase:
|
128 |
+
# https://github.com/ermongroup/ncsn
|
129 |
+
# https://github.com/ermongroup/ncsnv2
|
130 |
+
###########################################################################
|
131 |
+
|
132 |
+
|
133 |
+
class CRPBlock(nn.Module):
|
134 |
+
def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
|
135 |
+
super().__init__()
|
136 |
+
self.convs = nn.ModuleList()
|
137 |
+
for i in range(n_stages):
|
138 |
+
self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
|
139 |
+
self.n_stages = n_stages
|
140 |
+
if maxpool:
|
141 |
+
self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
|
142 |
+
else:
|
143 |
+
self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
|
144 |
+
|
145 |
+
self.act = act
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
x = self.act(x)
|
149 |
+
path = x
|
150 |
+
for i in range(self.n_stages):
|
151 |
+
path = self.pool(path)
|
152 |
+
path = self.convs[i](path)
|
153 |
+
x = path + x
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
class CondCRPBlock(nn.Module):
|
158 |
+
def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
|
159 |
+
super().__init__()
|
160 |
+
self.convs = nn.ModuleList()
|
161 |
+
self.norms = nn.ModuleList()
|
162 |
+
self.normalizer = normalizer
|
163 |
+
for i in range(n_stages):
|
164 |
+
self.norms.append(normalizer(features, num_classes, bias=True))
|
165 |
+
self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
|
166 |
+
|
167 |
+
self.n_stages = n_stages
|
168 |
+
self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
|
169 |
+
self.act = act
|
170 |
+
|
171 |
+
def forward(self, x, y):
|
172 |
+
x = self.act(x)
|
173 |
+
path = x
|
174 |
+
for i in range(self.n_stages):
|
175 |
+
path = self.norms[i](path, y)
|
176 |
+
path = self.pool(path)
|
177 |
+
path = self.convs[i](path)
|
178 |
+
|
179 |
+
x = path + x
|
180 |
+
return x
|
181 |
+
|
182 |
+
|
183 |
+
class RCUBlock(nn.Module):
|
184 |
+
def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
for i in range(n_blocks):
|
188 |
+
for j in range(n_stages):
|
189 |
+
setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
|
190 |
+
|
191 |
+
self.stride = 1
|
192 |
+
self.n_blocks = n_blocks
|
193 |
+
self.n_stages = n_stages
|
194 |
+
self.act = act
|
195 |
+
|
196 |
+
def forward(self, x):
|
197 |
+
for i in range(self.n_blocks):
|
198 |
+
residual = x
|
199 |
+
for j in range(self.n_stages):
|
200 |
+
x = self.act(x)
|
201 |
+
x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
|
202 |
+
|
203 |
+
x += residual
|
204 |
+
return x
|
205 |
+
|
206 |
+
|
207 |
+
class CondRCUBlock(nn.Module):
|
208 |
+
def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
|
209 |
+
super().__init__()
|
210 |
+
|
211 |
+
for i in range(n_blocks):
|
212 |
+
for j in range(n_stages):
|
213 |
+
setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
|
214 |
+
setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
|
215 |
+
|
216 |
+
self.stride = 1
|
217 |
+
self.n_blocks = n_blocks
|
218 |
+
self.n_stages = n_stages
|
219 |
+
self.act = act
|
220 |
+
self.normalizer = normalizer
|
221 |
+
|
222 |
+
def forward(self, x, y):
|
223 |
+
for i in range(self.n_blocks):
|
224 |
+
residual = x
|
225 |
+
for j in range(self.n_stages):
|
226 |
+
x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
|
227 |
+
x = self.act(x)
|
228 |
+
x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
|
229 |
+
|
230 |
+
x += residual
|
231 |
+
return x
|
232 |
+
|
233 |
+
|
234 |
+
class MSFBlock(nn.Module):
|
235 |
+
def __init__(self, in_planes, features):
|
236 |
+
super().__init__()
|
237 |
+
assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
|
238 |
+
self.convs = nn.ModuleList()
|
239 |
+
self.features = features
|
240 |
+
|
241 |
+
for i in range(len(in_planes)):
|
242 |
+
self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
|
243 |
+
|
244 |
+
def forward(self, xs, shape):
|
245 |
+
sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
|
246 |
+
for i in range(len(self.convs)):
|
247 |
+
h = self.convs[i](xs[i])
|
248 |
+
h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
|
249 |
+
sums += h
|
250 |
+
return sums
|
251 |
+
|
252 |
+
|
253 |
+
class CondMSFBlock(nn.Module):
|
254 |
+
def __init__(self, in_planes, features, num_classes, normalizer):
|
255 |
+
super().__init__()
|
256 |
+
assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
|
257 |
+
|
258 |
+
self.convs = nn.ModuleList()
|
259 |
+
self.norms = nn.ModuleList()
|
260 |
+
self.features = features
|
261 |
+
self.normalizer = normalizer
|
262 |
+
|
263 |
+
for i in range(len(in_planes)):
|
264 |
+
self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
|
265 |
+
self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
|
266 |
+
|
267 |
+
def forward(self, xs, y, shape):
|
268 |
+
sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
|
269 |
+
for i in range(len(self.convs)):
|
270 |
+
h = self.norms[i](xs[i], y)
|
271 |
+
h = self.convs[i](h)
|
272 |
+
h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
|
273 |
+
sums += h
|
274 |
+
return sums
|
275 |
+
|
276 |
+
|
277 |
+
class RefineBlock(nn.Module):
|
278 |
+
def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
|
279 |
+
super().__init__()
|
280 |
+
|
281 |
+
assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
|
282 |
+
self.n_blocks = n_blocks = len(in_planes)
|
283 |
+
|
284 |
+
self.adapt_convs = nn.ModuleList()
|
285 |
+
for i in range(n_blocks):
|
286 |
+
self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
|
287 |
+
|
288 |
+
self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
|
289 |
+
|
290 |
+
if not start:
|
291 |
+
self.msf = MSFBlock(in_planes, features)
|
292 |
+
|
293 |
+
self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
|
294 |
+
|
295 |
+
def forward(self, xs, output_shape):
|
296 |
+
assert isinstance(xs, tuple) or isinstance(xs, list)
|
297 |
+
hs = []
|
298 |
+
for i in range(len(xs)):
|
299 |
+
h = self.adapt_convs[i](xs[i])
|
300 |
+
hs.append(h)
|
301 |
+
|
302 |
+
if self.n_blocks > 1:
|
303 |
+
h = self.msf(hs, output_shape)
|
304 |
+
else:
|
305 |
+
h = hs[0]
|
306 |
+
|
307 |
+
h = self.crp(h)
|
308 |
+
h = self.output_convs(h)
|
309 |
+
|
310 |
+
return h
|
311 |
+
|
312 |
+
|
313 |
+
class CondRefineBlock(nn.Module):
|
314 |
+
def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
|
315 |
+
super().__init__()
|
316 |
+
|
317 |
+
assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
|
318 |
+
self.n_blocks = n_blocks = len(in_planes)
|
319 |
+
|
320 |
+
self.adapt_convs = nn.ModuleList()
|
321 |
+
for i in range(n_blocks):
|
322 |
+
self.adapt_convs.append(
|
323 |
+
CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
|
324 |
+
)
|
325 |
+
|
326 |
+
self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)
|
327 |
+
|
328 |
+
if not start:
|
329 |
+
self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
|
330 |
+
|
331 |
+
self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
|
332 |
+
|
333 |
+
def forward(self, xs, y, output_shape):
|
334 |
+
assert isinstance(xs, tuple) or isinstance(xs, list)
|
335 |
+
hs = []
|
336 |
+
for i in range(len(xs)):
|
337 |
+
h = self.adapt_convs[i](xs[i], y)
|
338 |
+
hs.append(h)
|
339 |
+
|
340 |
+
if self.n_blocks > 1:
|
341 |
+
h = self.msf(hs, y, output_shape)
|
342 |
+
else:
|
343 |
+
h = hs[0]
|
344 |
+
|
345 |
+
h = self.crp(h, y)
|
346 |
+
h = self.output_convs(h, y)
|
347 |
+
|
348 |
+
return h
|
349 |
+
|
350 |
+
|
351 |
+
class ConvMeanPool(nn.Module):
|
352 |
+
def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
|
353 |
+
super().__init__()
|
354 |
+
if not adjust_padding:
|
355 |
+
conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
|
356 |
+
self.conv = conv
|
357 |
+
else:
|
358 |
+
conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
|
359 |
+
|
360 |
+
self.conv = nn.Sequential(
|
361 |
+
nn.ZeroPad2d((1, 0, 1, 0)),
|
362 |
+
conv
|
363 |
+
)
|
364 |
+
|
365 |
+
def forward(self, inputs):
|
366 |
+
output = self.conv(inputs)
|
367 |
+
output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
|
368 |
+
output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
|
369 |
+
return output
|
370 |
+
|
371 |
+
|
372 |
+
class MeanPoolConv(nn.Module):
|
373 |
+
def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
|
374 |
+
super().__init__()
|
375 |
+
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
|
376 |
+
|
377 |
+
def forward(self, inputs):
|
378 |
+
output = inputs
|
379 |
+
output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
|
380 |
+
output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
|
381 |
+
return self.conv(output)
|
382 |
+
|
383 |
+
|
384 |
+
class UpsampleConv(nn.Module):
|
385 |
+
def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
|
386 |
+
super().__init__()
|
387 |
+
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
|
388 |
+
self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
|
389 |
+
|
390 |
+
def forward(self, inputs):
|
391 |
+
output = inputs
|
392 |
+
output = torch.cat([output, output, output, output], dim=1)
|
393 |
+
output = self.pixelshuffle(output)
|
394 |
+
return self.conv(output)
|
395 |
+
|
396 |
+
|
397 |
+
class ConditionalResidualBlock(nn.Module):
|
398 |
+
def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
|
399 |
+
normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
|
400 |
+
super().__init__()
|
401 |
+
self.non_linearity = act
|
402 |
+
self.input_dim = input_dim
|
403 |
+
self.output_dim = output_dim
|
404 |
+
self.resample = resample
|
405 |
+
self.normalization = normalization
|
406 |
+
if resample == 'down':
|
407 |
+
if dilation > 1:
|
408 |
+
self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
|
409 |
+
self.normalize2 = normalization(input_dim, num_classes)
|
410 |
+
self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
|
411 |
+
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
|
412 |
+
else:
|
413 |
+
self.conv1 = ncsn_conv3x3(input_dim, input_dim)
|
414 |
+
self.normalize2 = normalization(input_dim, num_classes)
|
415 |
+
self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
|
416 |
+
conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
|
417 |
+
|
418 |
+
elif resample is None:
|
419 |
+
if dilation > 1:
|
420 |
+
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
|
421 |
+
self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
|
422 |
+
self.normalize2 = normalization(output_dim, num_classes)
|
423 |
+
self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
|
424 |
+
else:
|
425 |
+
conv_shortcut = nn.Conv2d
|
426 |
+
self.conv1 = ncsn_conv3x3(input_dim, output_dim)
|
427 |
+
self.normalize2 = normalization(output_dim, num_classes)
|
428 |
+
self.conv2 = ncsn_conv3x3(output_dim, output_dim)
|
429 |
+
else:
|
430 |
+
raise Exception('invalid resample value')
|
431 |
+
|
432 |
+
if output_dim != input_dim or resample is not None:
|
433 |
+
self.shortcut = conv_shortcut(input_dim, output_dim)
|
434 |
+
|
435 |
+
self.normalize1 = normalization(input_dim, num_classes)
|
436 |
+
|
437 |
+
def forward(self, x, y):
|
438 |
+
output = self.normalize1(x, y)
|
439 |
+
output = self.non_linearity(output)
|
440 |
+
output = self.conv1(output)
|
441 |
+
output = self.normalize2(output, y)
|
442 |
+
output = self.non_linearity(output)
|
443 |
+
output = self.conv2(output)
|
444 |
+
|
445 |
+
if self.output_dim == self.input_dim and self.resample is None:
|
446 |
+
shortcut = x
|
447 |
+
else:
|
448 |
+
shortcut = self.shortcut(x)
|
449 |
+
|
450 |
+
return shortcut + output
|
451 |
+
|
452 |
+
|
453 |
+
class ResidualBlock(nn.Module):
|
454 |
+
def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
|
455 |
+
normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
|
456 |
+
super().__init__()
|
457 |
+
self.non_linearity = act
|
458 |
+
self.input_dim = input_dim
|
459 |
+
self.output_dim = output_dim
|
460 |
+
self.resample = resample
|
461 |
+
self.normalization = normalization
|
462 |
+
if resample == 'down':
|
463 |
+
if dilation > 1:
|
464 |
+
self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
|
465 |
+
self.normalize2 = normalization(input_dim)
|
466 |
+
self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
|
467 |
+
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
|
468 |
+
else:
|
469 |
+
self.conv1 = ncsn_conv3x3(input_dim, input_dim)
|
470 |
+
self.normalize2 = normalization(input_dim)
|
471 |
+
self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
|
472 |
+
conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
|
473 |
+
|
474 |
+
elif resample is None:
|
475 |
+
if dilation > 1:
|
476 |
+
conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
|
477 |
+
self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
|
478 |
+
self.normalize2 = normalization(output_dim)
|
479 |
+
self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
|
480 |
+
else:
|
481 |
+
# conv_shortcut = nn.Conv2d ### Something wierd here.
|
482 |
+
conv_shortcut = partial(ncsn_conv1x1)
|
483 |
+
self.conv1 = ncsn_conv3x3(input_dim, output_dim)
|
484 |
+
self.normalize2 = normalization(output_dim)
|
485 |
+
self.conv2 = ncsn_conv3x3(output_dim, output_dim)
|
486 |
+
else:
|
487 |
+
raise Exception('invalid resample value')
|
488 |
+
|
489 |
+
if output_dim != input_dim or resample is not None:
|
490 |
+
self.shortcut = conv_shortcut(input_dim, output_dim)
|
491 |
+
|
492 |
+
self.normalize1 = normalization(input_dim)
|
493 |
+
|
494 |
+
def forward(self, x):
|
495 |
+
output = self.normalize1(x)
|
496 |
+
output = self.non_linearity(output)
|
497 |
+
output = self.conv1(output)
|
498 |
+
output = self.normalize2(output)
|
499 |
+
output = self.non_linearity(output)
|
500 |
+
output = self.conv2(output)
|
501 |
+
|
502 |
+
if self.output_dim == self.input_dim and self.resample is None:
|
503 |
+
shortcut = x
|
504 |
+
else:
|
505 |
+
shortcut = self.shortcut(x)
|
506 |
+
|
507 |
+
return shortcut + output
|
508 |
+
|
509 |
+
|
510 |
+
###########################################################################
|
511 |
+
# Functions below are ported over from the DDPM codebase:
|
512 |
+
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
|
513 |
+
###########################################################################
|
514 |
+
|
515 |
+
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
|
516 |
+
assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
|
517 |
+
half_dim = embedding_dim // 2
|
518 |
+
# magic number 10000 is from transformers
|
519 |
+
emb = math.log(max_positions) / (half_dim - 1)
|
520 |
+
# emb = math.log(2.) / (half_dim - 1)
|
521 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
522 |
+
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
|
523 |
+
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
|
524 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
525 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
526 |
+
if embedding_dim % 2 == 1: # zero pad
|
527 |
+
emb = F.pad(emb, (0, 1), mode='constant')
|
528 |
+
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
529 |
+
return emb
|
530 |
+
|
531 |
+
|
532 |
+
def _einsum(a, b, c, x, y):
|
533 |
+
einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
|
534 |
+
return torch.einsum(einsum_str, x, y)
|
535 |
+
|
536 |
+
|
537 |
+
def contract_inner(x, y):
|
538 |
+
"""tensordot(x, y, 1)."""
|
539 |
+
x_chars = list(string.ascii_lowercase[:len(x.shape)])
|
540 |
+
y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
|
541 |
+
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
|
542 |
+
out_chars = x_chars[:-1] + y_chars[1:]
|
543 |
+
return _einsum(x_chars, y_chars, out_chars, x, y)
|
544 |
+
|
545 |
+
|
546 |
+
class NIN(nn.Module):
|
547 |
+
def __init__(self, in_dim, num_units, init_scale=0.1):
|
548 |
+
super().__init__()
|
549 |
+
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
|
550 |
+
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
|
551 |
+
|
552 |
+
def forward(self, x):
|
553 |
+
x = x.permute(0, 2, 3, 1)
|
554 |
+
y = contract_inner(x, self.W) + self.b
|
555 |
+
return y.permute(0, 3, 1, 2)
|
556 |
+
|
557 |
+
|
558 |
+
class AttnBlock(nn.Module):
|
559 |
+
"""Channel-wise self-attention block."""
|
560 |
+
def __init__(self, channels):
|
561 |
+
super().__init__()
|
562 |
+
self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
|
563 |
+
self.NIN_0 = NIN(channels, channels)
|
564 |
+
self.NIN_1 = NIN(channels, channels)
|
565 |
+
self.NIN_2 = NIN(channels, channels)
|
566 |
+
self.NIN_3 = NIN(channels, channels, init_scale=0.)
|
567 |
+
|
568 |
+
def forward(self, x):
|
569 |
+
B, C, H, W = x.shape
|
570 |
+
h = self.GroupNorm_0(x)
|
571 |
+
q = self.NIN_0(h)
|
572 |
+
k = self.NIN_1(h)
|
573 |
+
v = self.NIN_2(h)
|
574 |
+
|
575 |
+
w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
|
576 |
+
w = torch.reshape(w, (B, H, W, H * W))
|
577 |
+
w = F.softmax(w, dim=-1)
|
578 |
+
w = torch.reshape(w, (B, H, W, H, W))
|
579 |
+
h = torch.einsum('bhwij,bcij->bchw', w, v)
|
580 |
+
h = self.NIN_3(h)
|
581 |
+
return x + h
|
582 |
+
|
583 |
+
|
584 |
+
class Upsample(nn.Module):
|
585 |
+
def __init__(self, channels, with_conv=False):
|
586 |
+
super().__init__()
|
587 |
+
if with_conv:
|
588 |
+
self.Conv_0 = ddpm_conv3x3(channels, channels)
|
589 |
+
self.with_conv = with_conv
|
590 |
+
|
591 |
+
def forward(self, x):
|
592 |
+
B, C, H, W = x.shape
|
593 |
+
h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
|
594 |
+
if self.with_conv:
|
595 |
+
h = self.Conv_0(h)
|
596 |
+
return h
|
597 |
+
|
598 |
+
|
599 |
+
class Downsample(nn.Module):
|
600 |
+
def __init__(self, channels, with_conv=False):
|
601 |
+
super().__init__()
|
602 |
+
if with_conv:
|
603 |
+
self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
|
604 |
+
self.with_conv = with_conv
|
605 |
+
|
606 |
+
def forward(self, x):
|
607 |
+
B, C, H, W = x.shape
|
608 |
+
# Emulate 'SAME' padding
|
609 |
+
if self.with_conv:
|
610 |
+
x = F.pad(x, (0, 1, 0, 1))
|
611 |
+
x = self.Conv_0(x)
|
612 |
+
else:
|
613 |
+
x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
|
614 |
+
|
615 |
+
assert x.shape == (B, C, H // 2, W // 2)
|
616 |
+
return x
|
617 |
+
|
618 |
+
|
619 |
+
class ResnetBlockDDPM(nn.Module):
|
620 |
+
"""The ResNet Blocks used in DDPM."""
|
621 |
+
def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
|
622 |
+
super().__init__()
|
623 |
+
if out_ch is None:
|
624 |
+
out_ch = in_ch
|
625 |
+
self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
|
626 |
+
self.act = act
|
627 |
+
self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
|
628 |
+
if temb_dim is not None:
|
629 |
+
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
630 |
+
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
|
631 |
+
nn.init.zeros_(self.Dense_0.bias)
|
632 |
+
|
633 |
+
self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
|
634 |
+
self.Dropout_0 = nn.Dropout(dropout)
|
635 |
+
self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
|
636 |
+
if in_ch != out_ch:
|
637 |
+
if conv_shortcut:
|
638 |
+
self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
|
639 |
+
else:
|
640 |
+
self.NIN_0 = NIN(in_ch, out_ch)
|
641 |
+
self.out_ch = out_ch
|
642 |
+
self.in_ch = in_ch
|
643 |
+
self.conv_shortcut = conv_shortcut
|
644 |
+
|
645 |
+
def forward(self, x, temb=None):
|
646 |
+
B, C, H, W = x.shape
|
647 |
+
assert C == self.in_ch
|
648 |
+
out_ch = self.out_ch if self.out_ch else self.in_ch
|
649 |
+
h = self.act(self.GroupNorm_0(x))
|
650 |
+
h = self.Conv_0(h)
|
651 |
+
# Add bias to each feature map conditioned on the time embedding
|
652 |
+
if temb is not None:
|
653 |
+
h += self.Dense_0(self.act(temb))[:, :, None, None]
|
654 |
+
h = self.act(self.GroupNorm_1(h))
|
655 |
+
h = self.Dropout_0(h)
|
656 |
+
h = self.Conv_1(h)
|
657 |
+
if C != out_ch:
|
658 |
+
if self.conv_shortcut:
|
659 |
+
x = self.Conv_2(x)
|
660 |
+
else:
|
661 |
+
x = self.NIN_0(x)
|
662 |
+
return x + h
|
sgmse/backbones/ncsnpp_utils/layerspp.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
"""Layers for defining NCSN++.
|
18 |
+
"""
|
19 |
+
from . import layers
|
20 |
+
from . import up_or_down_sampling
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
conv1x1 = layers.ddpm_conv1x1
|
27 |
+
conv3x3 = layers.ddpm_conv3x3
|
28 |
+
NIN = layers.NIN
|
29 |
+
default_init = layers.default_init
|
30 |
+
|
31 |
+
|
32 |
+
class GaussianFourierProjection(nn.Module):
|
33 |
+
"""Gaussian Fourier embeddings for noise levels."""
|
34 |
+
|
35 |
+
def __init__(self, embedding_size=256, scale=1.0):
|
36 |
+
super().__init__()
|
37 |
+
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
|
41 |
+
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
42 |
+
|
43 |
+
|
44 |
+
class Combine(nn.Module):
|
45 |
+
"""Combine information from skip connections."""
|
46 |
+
|
47 |
+
def __init__(self, dim1, dim2, method='cat'):
|
48 |
+
super().__init__()
|
49 |
+
self.Conv_0 = conv1x1(dim1, dim2)
|
50 |
+
self.method = method
|
51 |
+
|
52 |
+
def forward(self, x, y):
|
53 |
+
h = self.Conv_0(x)
|
54 |
+
if self.method == 'cat':
|
55 |
+
return torch.cat([h, y], dim=1)
|
56 |
+
elif self.method == 'sum':
|
57 |
+
return h + y
|
58 |
+
else:
|
59 |
+
raise ValueError(f'Method {self.method} not recognized.')
|
60 |
+
|
61 |
+
|
62 |
+
class AttnBlockpp(nn.Module):
|
63 |
+
"""Channel-wise self-attention block. Modified from DDPM."""
|
64 |
+
|
65 |
+
def __init__(self, channels, skip_rescale=False, init_scale=0.):
|
66 |
+
super().__init__()
|
67 |
+
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
|
68 |
+
eps=1e-6)
|
69 |
+
self.NIN_0 = NIN(channels, channels)
|
70 |
+
self.NIN_1 = NIN(channels, channels)
|
71 |
+
self.NIN_2 = NIN(channels, channels)
|
72 |
+
self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
|
73 |
+
self.skip_rescale = skip_rescale
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
B, C, H, W = x.shape
|
77 |
+
h = self.GroupNorm_0(x)
|
78 |
+
q = self.NIN_0(h)
|
79 |
+
k = self.NIN_1(h)
|
80 |
+
v = self.NIN_2(h)
|
81 |
+
|
82 |
+
w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
|
83 |
+
w = torch.reshape(w, (B, H, W, H * W))
|
84 |
+
w = F.softmax(w, dim=-1)
|
85 |
+
w = torch.reshape(w, (B, H, W, H, W))
|
86 |
+
h = torch.einsum('bhwij,bcij->bchw', w, v)
|
87 |
+
h = self.NIN_3(h)
|
88 |
+
if not self.skip_rescale:
|
89 |
+
return x + h
|
90 |
+
else:
|
91 |
+
return (x + h) / np.sqrt(2.)
|
92 |
+
|
93 |
+
|
94 |
+
class Upsample(nn.Module):
|
95 |
+
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
|
96 |
+
fir_kernel=(1, 3, 3, 1)):
|
97 |
+
super().__init__()
|
98 |
+
out_ch = out_ch if out_ch else in_ch
|
99 |
+
if not fir:
|
100 |
+
if with_conv:
|
101 |
+
self.Conv_0 = conv3x3(in_ch, out_ch)
|
102 |
+
else:
|
103 |
+
if with_conv:
|
104 |
+
self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
|
105 |
+
kernel=3, up=True,
|
106 |
+
resample_kernel=fir_kernel,
|
107 |
+
use_bias=True,
|
108 |
+
kernel_init=default_init())
|
109 |
+
self.fir = fir
|
110 |
+
self.with_conv = with_conv
|
111 |
+
self.fir_kernel = fir_kernel
|
112 |
+
self.out_ch = out_ch
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
B, C, H, W = x.shape
|
116 |
+
if not self.fir:
|
117 |
+
h = F.interpolate(x, (H * 2, W * 2), 'nearest')
|
118 |
+
if self.with_conv:
|
119 |
+
h = self.Conv_0(h)
|
120 |
+
else:
|
121 |
+
if not self.with_conv:
|
122 |
+
h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
|
123 |
+
else:
|
124 |
+
h = self.Conv2d_0(x)
|
125 |
+
|
126 |
+
return h
|
127 |
+
|
128 |
+
|
129 |
+
class Downsample(nn.Module):
|
130 |
+
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
|
131 |
+
fir_kernel=(1, 3, 3, 1)):
|
132 |
+
super().__init__()
|
133 |
+
out_ch = out_ch if out_ch else in_ch
|
134 |
+
if not fir:
|
135 |
+
if with_conv:
|
136 |
+
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
|
137 |
+
else:
|
138 |
+
if with_conv:
|
139 |
+
self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
|
140 |
+
kernel=3, down=True,
|
141 |
+
resample_kernel=fir_kernel,
|
142 |
+
use_bias=True,
|
143 |
+
kernel_init=default_init())
|
144 |
+
self.fir = fir
|
145 |
+
self.fir_kernel = fir_kernel
|
146 |
+
self.with_conv = with_conv
|
147 |
+
self.out_ch = out_ch
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
B, C, H, W = x.shape
|
151 |
+
if not self.fir:
|
152 |
+
if self.with_conv:
|
153 |
+
x = F.pad(x, (0, 1, 0, 1))
|
154 |
+
x = self.Conv_0(x)
|
155 |
+
else:
|
156 |
+
x = F.avg_pool2d(x, 2, stride=2)
|
157 |
+
else:
|
158 |
+
if not self.with_conv:
|
159 |
+
x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
|
160 |
+
else:
|
161 |
+
x = self.Conv2d_0(x)
|
162 |
+
|
163 |
+
return x
|
164 |
+
|
165 |
+
|
166 |
+
class ResnetBlockDDPMpp(nn.Module):
|
167 |
+
"""ResBlock adapted from DDPM."""
|
168 |
+
|
169 |
+
def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
|
170 |
+
dropout=0.1, skip_rescale=False, init_scale=0.):
|
171 |
+
super().__init__()
|
172 |
+
out_ch = out_ch if out_ch else in_ch
|
173 |
+
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
174 |
+
self.Conv_0 = conv3x3(in_ch, out_ch)
|
175 |
+
if temb_dim is not None:
|
176 |
+
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
177 |
+
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
|
178 |
+
nn.init.zeros_(self.Dense_0.bias)
|
179 |
+
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
180 |
+
self.Dropout_0 = nn.Dropout(dropout)
|
181 |
+
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
182 |
+
if in_ch != out_ch:
|
183 |
+
if conv_shortcut:
|
184 |
+
self.Conv_2 = conv3x3(in_ch, out_ch)
|
185 |
+
else:
|
186 |
+
self.NIN_0 = NIN(in_ch, out_ch)
|
187 |
+
|
188 |
+
self.skip_rescale = skip_rescale
|
189 |
+
self.act = act
|
190 |
+
self.out_ch = out_ch
|
191 |
+
self.conv_shortcut = conv_shortcut
|
192 |
+
|
193 |
+
def forward(self, x, temb=None):
|
194 |
+
h = self.act(self.GroupNorm_0(x))
|
195 |
+
h = self.Conv_0(h)
|
196 |
+
if temb is not None:
|
197 |
+
h += self.Dense_0(self.act(temb))[:, :, None, None]
|
198 |
+
h = self.act(self.GroupNorm_1(h))
|
199 |
+
h = self.Dropout_0(h)
|
200 |
+
h = self.Conv_1(h)
|
201 |
+
if x.shape[1] != self.out_ch:
|
202 |
+
if self.conv_shortcut:
|
203 |
+
x = self.Conv_2(x)
|
204 |
+
else:
|
205 |
+
x = self.NIN_0(x)
|
206 |
+
if not self.skip_rescale:
|
207 |
+
return x + h
|
208 |
+
else:
|
209 |
+
return (x + h) / np.sqrt(2.)
|
210 |
+
|
211 |
+
|
212 |
+
class ResnetBlockBigGANpp(nn.Module):
|
213 |
+
def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
|
214 |
+
dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
|
215 |
+
skip_rescale=True, init_scale=0.):
|
216 |
+
super().__init__()
|
217 |
+
|
218 |
+
out_ch = out_ch if out_ch else in_ch
|
219 |
+
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
|
220 |
+
self.up = up
|
221 |
+
self.down = down
|
222 |
+
self.fir = fir
|
223 |
+
self.fir_kernel = fir_kernel
|
224 |
+
|
225 |
+
self.Conv_0 = conv3x3(in_ch, out_ch)
|
226 |
+
if temb_dim is not None:
|
227 |
+
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
228 |
+
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
|
229 |
+
nn.init.zeros_(self.Dense_0.bias)
|
230 |
+
|
231 |
+
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
|
232 |
+
self.Dropout_0 = nn.Dropout(dropout)
|
233 |
+
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
|
234 |
+
if in_ch != out_ch or up or down:
|
235 |
+
self.Conv_2 = conv1x1(in_ch, out_ch)
|
236 |
+
|
237 |
+
self.skip_rescale = skip_rescale
|
238 |
+
self.act = act
|
239 |
+
self.in_ch = in_ch
|
240 |
+
self.out_ch = out_ch
|
241 |
+
|
242 |
+
def forward(self, x, temb=None):
|
243 |
+
h = self.act(self.GroupNorm_0(x))
|
244 |
+
|
245 |
+
if self.up:
|
246 |
+
if self.fir:
|
247 |
+
h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
|
248 |
+
x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
|
249 |
+
else:
|
250 |
+
h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
|
251 |
+
x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
|
252 |
+
elif self.down:
|
253 |
+
if self.fir:
|
254 |
+
h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
|
255 |
+
x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
|
256 |
+
else:
|
257 |
+
h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
|
258 |
+
x = up_or_down_sampling.naive_downsample_2d(x, factor=2)
|
259 |
+
|
260 |
+
h = self.Conv_0(h)
|
261 |
+
# Add bias to each feature map conditioned on the time embedding
|
262 |
+
if temb is not None:
|
263 |
+
h += self.Dense_0(self.act(temb))[:, :, None, None]
|
264 |
+
h = self.act(self.GroupNorm_1(h))
|
265 |
+
h = self.Dropout_0(h)
|
266 |
+
h = self.Conv_1(h)
|
267 |
+
|
268 |
+
if self.in_ch != self.out_ch or self.up or self.down:
|
269 |
+
x = self.Conv_2(x)
|
270 |
+
|
271 |
+
if not self.skip_rescale:
|
272 |
+
return x + h
|
273 |
+
else:
|
274 |
+
return (x + h) / np.sqrt(2.)
|
sgmse/backbones/ncsnpp_utils/normalization.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Normalization layers."""
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch
|
19 |
+
import functools
|
20 |
+
|
21 |
+
|
22 |
+
def get_normalization(config, conditional=False):
|
23 |
+
"""Obtain normalization modules from the config file."""
|
24 |
+
norm = config.model.normalization
|
25 |
+
if conditional:
|
26 |
+
if norm == 'InstanceNorm++':
|
27 |
+
return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
|
28 |
+
else:
|
29 |
+
raise NotImplementedError(f'{norm} not implemented yet.')
|
30 |
+
else:
|
31 |
+
if norm == 'InstanceNorm':
|
32 |
+
return nn.InstanceNorm2d
|
33 |
+
elif norm == 'InstanceNorm++':
|
34 |
+
return InstanceNorm2dPlus
|
35 |
+
elif norm == 'VarianceNorm':
|
36 |
+
return VarianceNorm2d
|
37 |
+
elif norm == 'GroupNorm':
|
38 |
+
return nn.GroupNorm
|
39 |
+
else:
|
40 |
+
raise ValueError('Unknown normalization: %s' % norm)
|
41 |
+
|
42 |
+
|
43 |
+
class ConditionalBatchNorm2d(nn.Module):
|
44 |
+
def __init__(self, num_features, num_classes, bias=True):
|
45 |
+
super().__init__()
|
46 |
+
self.num_features = num_features
|
47 |
+
self.bias = bias
|
48 |
+
self.bn = nn.BatchNorm2d(num_features, affine=False)
|
49 |
+
if self.bias:
|
50 |
+
self.embed = nn.Embedding(num_classes, num_features * 2)
|
51 |
+
self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
|
52 |
+
self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
|
53 |
+
else:
|
54 |
+
self.embed = nn.Embedding(num_classes, num_features)
|
55 |
+
self.embed.weight.data.uniform_()
|
56 |
+
|
57 |
+
def forward(self, x, y):
|
58 |
+
out = self.bn(x)
|
59 |
+
if self.bias:
|
60 |
+
gamma, beta = self.embed(y).chunk(2, dim=1)
|
61 |
+
out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
|
62 |
+
else:
|
63 |
+
gamma = self.embed(y)
|
64 |
+
out = gamma.view(-1, self.num_features, 1, 1) * out
|
65 |
+
return out
|
66 |
+
|
67 |
+
|
68 |
+
class ConditionalInstanceNorm2d(nn.Module):
|
69 |
+
def __init__(self, num_features, num_classes, bias=True):
|
70 |
+
super().__init__()
|
71 |
+
self.num_features = num_features
|
72 |
+
self.bias = bias
|
73 |
+
self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
|
74 |
+
if bias:
|
75 |
+
self.embed = nn.Embedding(num_classes, num_features * 2)
|
76 |
+
self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
|
77 |
+
self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
|
78 |
+
else:
|
79 |
+
self.embed = nn.Embedding(num_classes, num_features)
|
80 |
+
self.embed.weight.data.uniform_()
|
81 |
+
|
82 |
+
def forward(self, x, y):
|
83 |
+
h = self.instance_norm(x)
|
84 |
+
if self.bias:
|
85 |
+
gamma, beta = self.embed(y).chunk(2, dim=-1)
|
86 |
+
out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
|
87 |
+
else:
|
88 |
+
gamma = self.embed(y)
|
89 |
+
out = gamma.view(-1, self.num_features, 1, 1) * h
|
90 |
+
return out
|
91 |
+
|
92 |
+
|
93 |
+
class ConditionalVarianceNorm2d(nn.Module):
|
94 |
+
def __init__(self, num_features, num_classes, bias=False):
|
95 |
+
super().__init__()
|
96 |
+
self.num_features = num_features
|
97 |
+
self.bias = bias
|
98 |
+
self.embed = nn.Embedding(num_classes, num_features)
|
99 |
+
self.embed.weight.data.normal_(1, 0.02)
|
100 |
+
|
101 |
+
def forward(self, x, y):
|
102 |
+
vars = torch.var(x, dim=(2, 3), keepdim=True)
|
103 |
+
h = x / torch.sqrt(vars + 1e-5)
|
104 |
+
|
105 |
+
gamma = self.embed(y)
|
106 |
+
out = gamma.view(-1, self.num_features, 1, 1) * h
|
107 |
+
return out
|
108 |
+
|
109 |
+
|
110 |
+
class VarianceNorm2d(nn.Module):
|
111 |
+
def __init__(self, num_features, bias=False):
|
112 |
+
super().__init__()
|
113 |
+
self.num_features = num_features
|
114 |
+
self.bias = bias
|
115 |
+
self.alpha = nn.Parameter(torch.zeros(num_features))
|
116 |
+
self.alpha.data.normal_(1, 0.02)
|
117 |
+
|
118 |
+
def forward(self, x):
|
119 |
+
vars = torch.var(x, dim=(2, 3), keepdim=True)
|
120 |
+
h = x / torch.sqrt(vars + 1e-5)
|
121 |
+
|
122 |
+
out = self.alpha.view(-1, self.num_features, 1, 1) * h
|
123 |
+
return out
|
124 |
+
|
125 |
+
|
126 |
+
class ConditionalNoneNorm2d(nn.Module):
|
127 |
+
def __init__(self, num_features, num_classes, bias=True):
|
128 |
+
super().__init__()
|
129 |
+
self.num_features = num_features
|
130 |
+
self.bias = bias
|
131 |
+
if bias:
|
132 |
+
self.embed = nn.Embedding(num_classes, num_features * 2)
|
133 |
+
self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
|
134 |
+
self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
|
135 |
+
else:
|
136 |
+
self.embed = nn.Embedding(num_classes, num_features)
|
137 |
+
self.embed.weight.data.uniform_()
|
138 |
+
|
139 |
+
def forward(self, x, y):
|
140 |
+
if self.bias:
|
141 |
+
gamma, beta = self.embed(y).chunk(2, dim=-1)
|
142 |
+
out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
|
143 |
+
else:
|
144 |
+
gamma = self.embed(y)
|
145 |
+
out = gamma.view(-1, self.num_features, 1, 1) * x
|
146 |
+
return out
|
147 |
+
|
148 |
+
|
149 |
+
class NoneNorm2d(nn.Module):
|
150 |
+
def __init__(self, num_features, bias=True):
|
151 |
+
super().__init__()
|
152 |
+
|
153 |
+
def forward(self, x):
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
class InstanceNorm2dPlus(nn.Module):
|
158 |
+
def __init__(self, num_features, bias=True):
|
159 |
+
super().__init__()
|
160 |
+
self.num_features = num_features
|
161 |
+
self.bias = bias
|
162 |
+
self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
|
163 |
+
self.alpha = nn.Parameter(torch.zeros(num_features))
|
164 |
+
self.gamma = nn.Parameter(torch.zeros(num_features))
|
165 |
+
self.alpha.data.normal_(1, 0.02)
|
166 |
+
self.gamma.data.normal_(1, 0.02)
|
167 |
+
if bias:
|
168 |
+
self.beta = nn.Parameter(torch.zeros(num_features))
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
means = torch.mean(x, dim=(2, 3))
|
172 |
+
m = torch.mean(means, dim=-1, keepdim=True)
|
173 |
+
v = torch.var(means, dim=-1, keepdim=True)
|
174 |
+
means = (means - m) / (torch.sqrt(v + 1e-5))
|
175 |
+
h = self.instance_norm(x)
|
176 |
+
|
177 |
+
if self.bias:
|
178 |
+
h = h + means[..., None, None] * self.alpha[..., None, None]
|
179 |
+
out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
|
180 |
+
else:
|
181 |
+
h = h + means[..., None, None] * self.alpha[..., None, None]
|
182 |
+
out = self.gamma.view(-1, self.num_features, 1, 1) * h
|
183 |
+
return out
|
184 |
+
|
185 |
+
|
186 |
+
class ConditionalInstanceNorm2dPlus(nn.Module):
|
187 |
+
def __init__(self, num_features, num_classes, bias=True):
|
188 |
+
super().__init__()
|
189 |
+
self.num_features = num_features
|
190 |
+
self.bias = bias
|
191 |
+
self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
|
192 |
+
if bias:
|
193 |
+
self.embed = nn.Embedding(num_classes, num_features * 3)
|
194 |
+
self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
|
195 |
+
self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
|
196 |
+
else:
|
197 |
+
self.embed = nn.Embedding(num_classes, 2 * num_features)
|
198 |
+
self.embed.weight.data.normal_(1, 0.02)
|
199 |
+
|
200 |
+
def forward(self, x, y):
|
201 |
+
means = torch.mean(x, dim=(2, 3))
|
202 |
+
m = torch.mean(means, dim=-1, keepdim=True)
|
203 |
+
v = torch.var(means, dim=-1, keepdim=True)
|
204 |
+
means = (means - m) / (torch.sqrt(v + 1e-5))
|
205 |
+
h = self.instance_norm(x)
|
206 |
+
|
207 |
+
if self.bias:
|
208 |
+
gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
|
209 |
+
h = h + means[..., None, None] * alpha[..., None, None]
|
210 |
+
out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
|
211 |
+
else:
|
212 |
+
gamma, alpha = self.embed(y).chunk(2, dim=-1)
|
213 |
+
h = h + means[..., None, None] * alpha[..., None, None]
|
214 |
+
out = gamma.view(-1, self.num_features, 1, 1) * h
|
215 |
+
return out
|
sgmse/backbones/ncsnpp_utils/op/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .upfirdn2d import upfirdn2d
|
sgmse/backbones/ncsnpp_utils/op/fused_act.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch.autograd import Function
|
7 |
+
from torch.utils.cpp_extension import load
|
8 |
+
|
9 |
+
|
10 |
+
module_path = os.path.dirname(__file__)
|
11 |
+
fused = load(
|
12 |
+
"fused",
|
13 |
+
sources=[
|
14 |
+
os.path.join(module_path, "fused_bias_act.cpp"),
|
15 |
+
os.path.join(module_path, "fused_bias_act_kernel.cu"),
|
16 |
+
],
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
class FusedLeakyReLUFunctionBackward(Function):
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, grad_output, out, negative_slope, scale):
|
23 |
+
ctx.save_for_backward(out)
|
24 |
+
ctx.negative_slope = negative_slope
|
25 |
+
ctx.scale = scale
|
26 |
+
|
27 |
+
empty = grad_output.new_empty(0)
|
28 |
+
|
29 |
+
grad_input = fused.fused_bias_act(
|
30 |
+
grad_output, empty, out, 3, 1, negative_slope, scale
|
31 |
+
)
|
32 |
+
|
33 |
+
dim = [0]
|
34 |
+
|
35 |
+
if grad_input.ndim > 2:
|
36 |
+
dim += list(range(2, grad_input.ndim))
|
37 |
+
|
38 |
+
grad_bias = grad_input.sum(dim).detach()
|
39 |
+
|
40 |
+
return grad_input, grad_bias
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def backward(ctx, gradgrad_input, gradgrad_bias):
|
44 |
+
out, = ctx.saved_tensors
|
45 |
+
gradgrad_out = fused.fused_bias_act(
|
46 |
+
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
|
47 |
+
)
|
48 |
+
|
49 |
+
return gradgrad_out, None, None, None
|
50 |
+
|
51 |
+
|
52 |
+
class FusedLeakyReLUFunction(Function):
|
53 |
+
@staticmethod
|
54 |
+
def forward(ctx, input, bias, negative_slope, scale):
|
55 |
+
empty = input.new_empty(0)
|
56 |
+
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
|
57 |
+
ctx.save_for_backward(out)
|
58 |
+
ctx.negative_slope = negative_slope
|
59 |
+
ctx.scale = scale
|
60 |
+
|
61 |
+
return out
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def backward(ctx, grad_output):
|
65 |
+
out, = ctx.saved_tensors
|
66 |
+
|
67 |
+
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
68 |
+
grad_output, out, ctx.negative_slope, ctx.scale
|
69 |
+
)
|
70 |
+
|
71 |
+
return grad_input, grad_bias, None, None
|
72 |
+
|
73 |
+
|
74 |
+
class FusedLeakyReLU(nn.Module):
|
75 |
+
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
76 |
+
super().__init__()
|
77 |
+
|
78 |
+
self.bias = nn.Parameter(torch.zeros(channel))
|
79 |
+
self.negative_slope = negative_slope
|
80 |
+
self.scale = scale
|
81 |
+
|
82 |
+
def forward(self, input):
|
83 |
+
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
84 |
+
|
85 |
+
|
86 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
87 |
+
if input.device.type == "cpu":
|
88 |
+
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
89 |
+
return (
|
90 |
+
F.leaky_relu(
|
91 |
+
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
|
92 |
+
)
|
93 |
+
* scale
|
94 |
+
)
|
95 |
+
|
96 |
+
else:
|
97 |
+
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
5 |
+
int act, int grad, float alpha, float scale);
|
6 |
+
|
7 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
8 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
9 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
10 |
+
|
11 |
+
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
12 |
+
int act, int grad, float alpha, float scale) {
|
13 |
+
CHECK_CUDA(input);
|
14 |
+
CHECK_CUDA(bias);
|
15 |
+
|
16 |
+
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
|
17 |
+
}
|
18 |
+
|
19 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
20 |
+
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
|
21 |
+
}
|
sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
|
18 |
+
template <typename scalar_t>
|
19 |
+
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
|
20 |
+
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
|
21 |
+
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
|
22 |
+
|
23 |
+
scalar_t zero = 0.0;
|
24 |
+
|
25 |
+
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
|
26 |
+
scalar_t x = p_x[xi];
|
27 |
+
|
28 |
+
if (use_bias) {
|
29 |
+
x += p_b[(xi / step_b) % size_b];
|
30 |
+
}
|
31 |
+
|
32 |
+
scalar_t ref = use_ref ? p_ref[xi] : zero;
|
33 |
+
|
34 |
+
scalar_t y;
|
35 |
+
|
36 |
+
switch (act * 10 + grad) {
|
37 |
+
default:
|
38 |
+
case 10: y = x; break;
|
39 |
+
case 11: y = x; break;
|
40 |
+
case 12: y = 0.0; break;
|
41 |
+
|
42 |
+
case 30: y = (x > 0.0) ? x : x * alpha; break;
|
43 |
+
case 31: y = (ref > 0.0) ? x : x * alpha; break;
|
44 |
+
case 32: y = 0.0; break;
|
45 |
+
}
|
46 |
+
|
47 |
+
out[xi] = y * scale;
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
53 |
+
int act, int grad, float alpha, float scale) {
|
54 |
+
int curDevice = -1;
|
55 |
+
cudaGetDevice(&curDevice);
|
56 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
57 |
+
|
58 |
+
auto x = input.contiguous();
|
59 |
+
auto b = bias.contiguous();
|
60 |
+
auto ref = refer.contiguous();
|
61 |
+
|
62 |
+
int use_bias = b.numel() ? 1 : 0;
|
63 |
+
int use_ref = ref.numel() ? 1 : 0;
|
64 |
+
|
65 |
+
int size_x = x.numel();
|
66 |
+
int size_b = b.numel();
|
67 |
+
int step_b = 1;
|
68 |
+
|
69 |
+
for (int i = 1 + 1; i < x.dim(); i++) {
|
70 |
+
step_b *= x.size(i);
|
71 |
+
}
|
72 |
+
|
73 |
+
int loop_x = 4;
|
74 |
+
int block_size = 4 * 32;
|
75 |
+
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
|
76 |
+
|
77 |
+
auto y = torch::empty_like(x);
|
78 |
+
|
79 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
|
80 |
+
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
81 |
+
y.data_ptr<scalar_t>(),
|
82 |
+
x.data_ptr<scalar_t>(),
|
83 |
+
b.data_ptr<scalar_t>(),
|
84 |
+
ref.data_ptr<scalar_t>(),
|
85 |
+
act,
|
86 |
+
grad,
|
87 |
+
alpha,
|
88 |
+
scale,
|
89 |
+
loop_x,
|
90 |
+
size_x,
|
91 |
+
step_b,
|
92 |
+
size_b,
|
93 |
+
use_bias,
|
94 |
+
use_ref
|
95 |
+
);
|
96 |
+
});
|
97 |
+
|
98 |
+
return y;
|
99 |
+
}
|
sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
5 |
+
int up_x, int up_y, int down_x, int down_y,
|
6 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
|
7 |
+
|
8 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
9 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
10 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
11 |
+
|
12 |
+
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
|
13 |
+
int up_x, int up_y, int down_x, int down_y,
|
14 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
|
15 |
+
CHECK_CUDA(input);
|
16 |
+
CHECK_CUDA(kernel);
|
17 |
+
|
18 |
+
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
|
19 |
+
}
|
20 |
+
|
21 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
22 |
+
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
|
23 |
+
}
|
sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from torch.autograd import Function
|
6 |
+
from torch.utils.cpp_extension import load
|
7 |
+
|
8 |
+
|
9 |
+
module_path = os.path.dirname(__file__)
|
10 |
+
|
11 |
+
if torch.cuda.is_available():
|
12 |
+
upfirdn2d_op = load(
|
13 |
+
"upfirdn2d",
|
14 |
+
sources=[
|
15 |
+
os.path.join(module_path, "upfirdn2d.cpp"),
|
16 |
+
os.path.join(module_path, "upfirdn2d_kernel.cu"),
|
17 |
+
],
|
18 |
+
)
|
19 |
+
else:
|
20 |
+
upfirdn2d_op = None
|
21 |
+
|
22 |
+
class UpFirDn2dBackward(Function):
|
23 |
+
@staticmethod
|
24 |
+
def forward(
|
25 |
+
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
|
26 |
+
):
|
27 |
+
|
28 |
+
up_x, up_y = up
|
29 |
+
down_x, down_y = down
|
30 |
+
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
31 |
+
|
32 |
+
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
33 |
+
|
34 |
+
grad_input = upfirdn2d_op.upfirdn2d(
|
35 |
+
grad_output,
|
36 |
+
grad_kernel,
|
37 |
+
down_x,
|
38 |
+
down_y,
|
39 |
+
up_x,
|
40 |
+
up_y,
|
41 |
+
g_pad_x0,
|
42 |
+
g_pad_x1,
|
43 |
+
g_pad_y0,
|
44 |
+
g_pad_y1,
|
45 |
+
)
|
46 |
+
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
|
47 |
+
|
48 |
+
ctx.save_for_backward(kernel)
|
49 |
+
|
50 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
51 |
+
|
52 |
+
ctx.up_x = up_x
|
53 |
+
ctx.up_y = up_y
|
54 |
+
ctx.down_x = down_x
|
55 |
+
ctx.down_y = down_y
|
56 |
+
ctx.pad_x0 = pad_x0
|
57 |
+
ctx.pad_x1 = pad_x1
|
58 |
+
ctx.pad_y0 = pad_y0
|
59 |
+
ctx.pad_y1 = pad_y1
|
60 |
+
ctx.in_size = in_size
|
61 |
+
ctx.out_size = out_size
|
62 |
+
|
63 |
+
return grad_input
|
64 |
+
|
65 |
+
@staticmethod
|
66 |
+
def backward(ctx, gradgrad_input):
|
67 |
+
kernel, = ctx.saved_tensors
|
68 |
+
|
69 |
+
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
|
70 |
+
|
71 |
+
gradgrad_out = upfirdn2d_op.upfirdn2d(
|
72 |
+
gradgrad_input,
|
73 |
+
kernel,
|
74 |
+
ctx.up_x,
|
75 |
+
ctx.up_y,
|
76 |
+
ctx.down_x,
|
77 |
+
ctx.down_y,
|
78 |
+
ctx.pad_x0,
|
79 |
+
ctx.pad_x1,
|
80 |
+
ctx.pad_y0,
|
81 |
+
ctx.pad_y1,
|
82 |
+
)
|
83 |
+
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
|
84 |
+
gradgrad_out = gradgrad_out.view(
|
85 |
+
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
|
86 |
+
)
|
87 |
+
|
88 |
+
return gradgrad_out, None, None, None, None, None, None, None, None
|
89 |
+
|
90 |
+
|
91 |
+
class UpFirDn2d(Function):
|
92 |
+
@staticmethod
|
93 |
+
def forward(ctx, input, kernel, up, down, pad):
|
94 |
+
up_x, up_y = up
|
95 |
+
down_x, down_y = down
|
96 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
97 |
+
|
98 |
+
kernel_h, kernel_w = kernel.shape
|
99 |
+
batch, channel, in_h, in_w = input.shape
|
100 |
+
ctx.in_size = input.shape
|
101 |
+
|
102 |
+
input = input.reshape(-1, in_h, in_w, 1)
|
103 |
+
|
104 |
+
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
|
105 |
+
|
106 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
107 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
108 |
+
ctx.out_size = (out_h, out_w)
|
109 |
+
|
110 |
+
ctx.up = (up_x, up_y)
|
111 |
+
ctx.down = (down_x, down_y)
|
112 |
+
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
|
113 |
+
|
114 |
+
g_pad_x0 = kernel_w - pad_x0 - 1
|
115 |
+
g_pad_y0 = kernel_h - pad_y0 - 1
|
116 |
+
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
|
117 |
+
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
|
118 |
+
|
119 |
+
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
|
120 |
+
|
121 |
+
out = upfirdn2d_op.upfirdn2d(
|
122 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
123 |
+
)
|
124 |
+
# out = out.view(major, out_h, out_w, minor)
|
125 |
+
out = out.view(-1, channel, out_h, out_w)
|
126 |
+
|
127 |
+
return out
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def backward(ctx, grad_output):
|
131 |
+
kernel, grad_kernel = ctx.saved_tensors
|
132 |
+
|
133 |
+
grad_input = UpFirDn2dBackward.apply(
|
134 |
+
grad_output,
|
135 |
+
kernel,
|
136 |
+
grad_kernel,
|
137 |
+
ctx.up,
|
138 |
+
ctx.down,
|
139 |
+
ctx.pad,
|
140 |
+
ctx.g_pad,
|
141 |
+
ctx.in_size,
|
142 |
+
ctx.out_size,
|
143 |
+
)
|
144 |
+
|
145 |
+
return grad_input, None, None, None, None
|
146 |
+
|
147 |
+
|
148 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
149 |
+
if input.device.type == "cpu":
|
150 |
+
out = upfirdn2d_native(
|
151 |
+
input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
|
152 |
+
)
|
153 |
+
|
154 |
+
else:
|
155 |
+
out = UpFirDn2d.apply(
|
156 |
+
input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
|
157 |
+
)
|
158 |
+
|
159 |
+
return out
|
160 |
+
|
161 |
+
|
162 |
+
def upfirdn2d_native(
|
163 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
164 |
+
):
|
165 |
+
_, channel, in_h, in_w = input.shape
|
166 |
+
input = input.reshape(-1, in_h, in_w, 1)
|
167 |
+
|
168 |
+
_, in_h, in_w, minor = input.shape
|
169 |
+
kernel_h, kernel_w = kernel.shape
|
170 |
+
|
171 |
+
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
172 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
173 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
174 |
+
|
175 |
+
out = F.pad(
|
176 |
+
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
177 |
+
)
|
178 |
+
out = out[
|
179 |
+
:,
|
180 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
181 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
182 |
+
:,
|
183 |
+
]
|
184 |
+
|
185 |
+
out = out.permute(0, 3, 1, 2)
|
186 |
+
out = out.reshape(
|
187 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
188 |
+
)
|
189 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
190 |
+
out = F.conv2d(out, w)
|
191 |
+
out = out.reshape(
|
192 |
+
-1,
|
193 |
+
minor,
|
194 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
195 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
196 |
+
)
|
197 |
+
out = out.permute(0, 2, 3, 1)
|
198 |
+
out = out[:, ::down_y, ::down_x, :]
|
199 |
+
|
200 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
201 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
202 |
+
|
203 |
+
return out.view(-1, channel, out_h, out_w)
|
sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
12 |
+
#include <ATen/cuda/CUDAContext.h>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
18 |
+
int c = a / b;
|
19 |
+
|
20 |
+
if (c * b > a) {
|
21 |
+
c--;
|
22 |
+
}
|
23 |
+
|
24 |
+
return c;
|
25 |
+
}
|
26 |
+
|
27 |
+
struct UpFirDn2DKernelParams {
|
28 |
+
int up_x;
|
29 |
+
int up_y;
|
30 |
+
int down_x;
|
31 |
+
int down_y;
|
32 |
+
int pad_x0;
|
33 |
+
int pad_x1;
|
34 |
+
int pad_y0;
|
35 |
+
int pad_y1;
|
36 |
+
|
37 |
+
int major_dim;
|
38 |
+
int in_h;
|
39 |
+
int in_w;
|
40 |
+
int minor_dim;
|
41 |
+
int kernel_h;
|
42 |
+
int kernel_w;
|
43 |
+
int out_h;
|
44 |
+
int out_w;
|
45 |
+
int loop_major;
|
46 |
+
int loop_x;
|
47 |
+
};
|
48 |
+
|
49 |
+
template <typename scalar_t>
|
50 |
+
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
|
51 |
+
const scalar_t *kernel,
|
52 |
+
const UpFirDn2DKernelParams p) {
|
53 |
+
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
54 |
+
int out_y = minor_idx / p.minor_dim;
|
55 |
+
minor_idx -= out_y * p.minor_dim;
|
56 |
+
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
|
57 |
+
int major_idx_base = blockIdx.z * p.loop_major;
|
58 |
+
|
59 |
+
if (out_x_base >= p.out_w || out_y >= p.out_h ||
|
60 |
+
major_idx_base >= p.major_dim) {
|
61 |
+
return;
|
62 |
+
}
|
63 |
+
|
64 |
+
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
|
65 |
+
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
|
66 |
+
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
|
67 |
+
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
|
68 |
+
|
69 |
+
for (int loop_major = 0, major_idx = major_idx_base;
|
70 |
+
loop_major < p.loop_major && major_idx < p.major_dim;
|
71 |
+
loop_major++, major_idx++) {
|
72 |
+
for (int loop_x = 0, out_x = out_x_base;
|
73 |
+
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
|
74 |
+
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
|
75 |
+
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
|
76 |
+
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
|
77 |
+
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
|
78 |
+
|
79 |
+
const scalar_t *x_p =
|
80 |
+
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
|
81 |
+
minor_idx];
|
82 |
+
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
|
83 |
+
int x_px = p.minor_dim;
|
84 |
+
int k_px = -p.up_x;
|
85 |
+
int x_py = p.in_w * p.minor_dim;
|
86 |
+
int k_py = -p.up_y * p.kernel_w;
|
87 |
+
|
88 |
+
scalar_t v = 0.0f;
|
89 |
+
|
90 |
+
for (int y = 0; y < h; y++) {
|
91 |
+
for (int x = 0; x < w; x++) {
|
92 |
+
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
|
93 |
+
x_p += x_px;
|
94 |
+
k_p += k_px;
|
95 |
+
}
|
96 |
+
|
97 |
+
x_p += x_py - w * x_px;
|
98 |
+
k_p += k_py - w * k_px;
|
99 |
+
}
|
100 |
+
|
101 |
+
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
102 |
+
minor_idx] = v;
|
103 |
+
}
|
104 |
+
}
|
105 |
+
}
|
106 |
+
|
107 |
+
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
|
108 |
+
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
|
109 |
+
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
|
110 |
+
const scalar_t *kernel,
|
111 |
+
const UpFirDn2DKernelParams p) {
|
112 |
+
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
|
113 |
+
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
|
114 |
+
|
115 |
+
__shared__ volatile float sk[kernel_h][kernel_w];
|
116 |
+
__shared__ volatile float sx[tile_in_h][tile_in_w];
|
117 |
+
|
118 |
+
int minor_idx = blockIdx.x;
|
119 |
+
int tile_out_y = minor_idx / p.minor_dim;
|
120 |
+
minor_idx -= tile_out_y * p.minor_dim;
|
121 |
+
tile_out_y *= tile_out_h;
|
122 |
+
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
|
123 |
+
int major_idx_base = blockIdx.z * p.loop_major;
|
124 |
+
|
125 |
+
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
|
126 |
+
major_idx_base >= p.major_dim) {
|
127 |
+
return;
|
128 |
+
}
|
129 |
+
|
130 |
+
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
|
131 |
+
tap_idx += blockDim.x) {
|
132 |
+
int ky = tap_idx / kernel_w;
|
133 |
+
int kx = tap_idx - ky * kernel_w;
|
134 |
+
scalar_t v = 0.0;
|
135 |
+
|
136 |
+
if (kx < p.kernel_w & ky < p.kernel_h) {
|
137 |
+
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
|
138 |
+
}
|
139 |
+
|
140 |
+
sk[ky][kx] = v;
|
141 |
+
}
|
142 |
+
|
143 |
+
for (int loop_major = 0, major_idx = major_idx_base;
|
144 |
+
loop_major < p.loop_major & major_idx < p.major_dim;
|
145 |
+
loop_major++, major_idx++) {
|
146 |
+
for (int loop_x = 0, tile_out_x = tile_out_x_base;
|
147 |
+
loop_x < p.loop_x & tile_out_x < p.out_w;
|
148 |
+
loop_x++, tile_out_x += tile_out_w) {
|
149 |
+
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
|
150 |
+
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
|
151 |
+
int tile_in_x = floor_div(tile_mid_x, up_x);
|
152 |
+
int tile_in_y = floor_div(tile_mid_y, up_y);
|
153 |
+
|
154 |
+
__syncthreads();
|
155 |
+
|
156 |
+
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
|
157 |
+
in_idx += blockDim.x) {
|
158 |
+
int rel_in_y = in_idx / tile_in_w;
|
159 |
+
int rel_in_x = in_idx - rel_in_y * tile_in_w;
|
160 |
+
int in_x = rel_in_x + tile_in_x;
|
161 |
+
int in_y = rel_in_y + tile_in_y;
|
162 |
+
|
163 |
+
scalar_t v = 0.0;
|
164 |
+
|
165 |
+
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
|
166 |
+
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
|
167 |
+
p.minor_dim +
|
168 |
+
minor_idx];
|
169 |
+
}
|
170 |
+
|
171 |
+
sx[rel_in_y][rel_in_x] = v;
|
172 |
+
}
|
173 |
+
|
174 |
+
__syncthreads();
|
175 |
+
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
|
176 |
+
out_idx += blockDim.x) {
|
177 |
+
int rel_out_y = out_idx / tile_out_w;
|
178 |
+
int rel_out_x = out_idx - rel_out_y * tile_out_w;
|
179 |
+
int out_x = rel_out_x + tile_out_x;
|
180 |
+
int out_y = rel_out_y + tile_out_y;
|
181 |
+
|
182 |
+
int mid_x = tile_mid_x + rel_out_x * down_x;
|
183 |
+
int mid_y = tile_mid_y + rel_out_y * down_y;
|
184 |
+
int in_x = floor_div(mid_x, up_x);
|
185 |
+
int in_y = floor_div(mid_y, up_y);
|
186 |
+
int rel_in_x = in_x - tile_in_x;
|
187 |
+
int rel_in_y = in_y - tile_in_y;
|
188 |
+
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
189 |
+
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
190 |
+
|
191 |
+
scalar_t v = 0.0;
|
192 |
+
|
193 |
+
#pragma unroll
|
194 |
+
for (int y = 0; y < kernel_h / up_y; y++)
|
195 |
+
#pragma unroll
|
196 |
+
for (int x = 0; x < kernel_w / up_x; x++)
|
197 |
+
v += sx[rel_in_y + y][rel_in_x + x] *
|
198 |
+
sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
199 |
+
|
200 |
+
if (out_x < p.out_w & out_y < p.out_h) {
|
201 |
+
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
202 |
+
minor_idx] = v;
|
203 |
+
}
|
204 |
+
}
|
205 |
+
}
|
206 |
+
}
|
207 |
+
}
|
208 |
+
|
209 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
|
210 |
+
const torch::Tensor &kernel, int up_x, int up_y,
|
211 |
+
int down_x, int down_y, int pad_x0, int pad_x1,
|
212 |
+
int pad_y0, int pad_y1) {
|
213 |
+
int curDevice = -1;
|
214 |
+
cudaGetDevice(&curDevice);
|
215 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
216 |
+
|
217 |
+
UpFirDn2DKernelParams p;
|
218 |
+
|
219 |
+
auto x = input.contiguous();
|
220 |
+
auto k = kernel.contiguous();
|
221 |
+
|
222 |
+
p.major_dim = x.size(0);
|
223 |
+
p.in_h = x.size(1);
|
224 |
+
p.in_w = x.size(2);
|
225 |
+
p.minor_dim = x.size(3);
|
226 |
+
p.kernel_h = k.size(0);
|
227 |
+
p.kernel_w = k.size(1);
|
228 |
+
p.up_x = up_x;
|
229 |
+
p.up_y = up_y;
|
230 |
+
p.down_x = down_x;
|
231 |
+
p.down_y = down_y;
|
232 |
+
p.pad_x0 = pad_x0;
|
233 |
+
p.pad_x1 = pad_x1;
|
234 |
+
p.pad_y0 = pad_y0;
|
235 |
+
p.pad_y1 = pad_y1;
|
236 |
+
|
237 |
+
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
|
238 |
+
p.down_y;
|
239 |
+
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
|
240 |
+
p.down_x;
|
241 |
+
|
242 |
+
auto out =
|
243 |
+
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
|
244 |
+
|
245 |
+
int mode = -1;
|
246 |
+
|
247 |
+
int tile_out_h = -1;
|
248 |
+
int tile_out_w = -1;
|
249 |
+
|
250 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
251 |
+
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
252 |
+
mode = 1;
|
253 |
+
tile_out_h = 16;
|
254 |
+
tile_out_w = 64;
|
255 |
+
}
|
256 |
+
|
257 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
258 |
+
p.kernel_h <= 3 && p.kernel_w <= 3) {
|
259 |
+
mode = 2;
|
260 |
+
tile_out_h = 16;
|
261 |
+
tile_out_w = 64;
|
262 |
+
}
|
263 |
+
|
264 |
+
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
265 |
+
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
266 |
+
mode = 3;
|
267 |
+
tile_out_h = 16;
|
268 |
+
tile_out_w = 64;
|
269 |
+
}
|
270 |
+
|
271 |
+
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
272 |
+
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
273 |
+
mode = 4;
|
274 |
+
tile_out_h = 16;
|
275 |
+
tile_out_w = 64;
|
276 |
+
}
|
277 |
+
|
278 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
279 |
+
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
280 |
+
mode = 5;
|
281 |
+
tile_out_h = 8;
|
282 |
+
tile_out_w = 32;
|
283 |
+
}
|
284 |
+
|
285 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
286 |
+
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
287 |
+
mode = 6;
|
288 |
+
tile_out_h = 8;
|
289 |
+
tile_out_w = 32;
|
290 |
+
}
|
291 |
+
|
292 |
+
dim3 block_size;
|
293 |
+
dim3 grid_size;
|
294 |
+
|
295 |
+
if (tile_out_h > 0 && tile_out_w > 0) {
|
296 |
+
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
297 |
+
p.loop_x = 1;
|
298 |
+
block_size = dim3(32 * 8, 1, 1);
|
299 |
+
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
|
300 |
+
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
|
301 |
+
(p.major_dim - 1) / p.loop_major + 1);
|
302 |
+
} else {
|
303 |
+
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
304 |
+
p.loop_x = 4;
|
305 |
+
block_size = dim3(4, 32, 1);
|
306 |
+
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
|
307 |
+
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
|
308 |
+
(p.major_dim - 1) / p.loop_major + 1);
|
309 |
+
}
|
310 |
+
|
311 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
|
312 |
+
switch (mode) {
|
313 |
+
case 1:
|
314 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
|
315 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
316 |
+
x.data_ptr<scalar_t>(),
|
317 |
+
k.data_ptr<scalar_t>(), p);
|
318 |
+
|
319 |
+
break;
|
320 |
+
|
321 |
+
case 2:
|
322 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
|
323 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
324 |
+
x.data_ptr<scalar_t>(),
|
325 |
+
k.data_ptr<scalar_t>(), p);
|
326 |
+
|
327 |
+
break;
|
328 |
+
|
329 |
+
case 3:
|
330 |
+
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
|
331 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
332 |
+
x.data_ptr<scalar_t>(),
|
333 |
+
k.data_ptr<scalar_t>(), p);
|
334 |
+
|
335 |
+
break;
|
336 |
+
|
337 |
+
case 4:
|
338 |
+
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
|
339 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
340 |
+
x.data_ptr<scalar_t>(),
|
341 |
+
k.data_ptr<scalar_t>(), p);
|
342 |
+
|
343 |
+
break;
|
344 |
+
|
345 |
+
case 5:
|
346 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
347 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
348 |
+
x.data_ptr<scalar_t>(),
|
349 |
+
k.data_ptr<scalar_t>(), p);
|
350 |
+
|
351 |
+
break;
|
352 |
+
|
353 |
+
case 6:
|
354 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
355 |
+
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
356 |
+
x.data_ptr<scalar_t>(),
|
357 |
+
k.data_ptr<scalar_t>(), p);
|
358 |
+
|
359 |
+
break;
|
360 |
+
|
361 |
+
default:
|
362 |
+
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
363 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
|
364 |
+
k.data_ptr<scalar_t>(), p);
|
365 |
+
}
|
366 |
+
});
|
367 |
+
|
368 |
+
return out;
|
369 |
+
}
|
sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Layers used for up-sampling or down-sampling images.
|
2 |
+
|
3 |
+
Many functions are ported from https://github.com/NVlabs/stylegan2.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import numpy as np
|
10 |
+
from .op import upfirdn2d
|
11 |
+
|
12 |
+
|
13 |
+
# Function ported from StyleGAN2
|
14 |
+
def get_weight(module,
|
15 |
+
shape,
|
16 |
+
weight_var='weight',
|
17 |
+
kernel_init=None):
|
18 |
+
"""Get/create weight tensor for a convolution or fully-connected layer."""
|
19 |
+
|
20 |
+
return module.param(weight_var, kernel_init, shape)
|
21 |
+
|
22 |
+
|
23 |
+
class Conv2d(nn.Module):
|
24 |
+
"""Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
|
25 |
+
|
26 |
+
def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
|
27 |
+
resample_kernel=(1, 3, 3, 1),
|
28 |
+
use_bias=True,
|
29 |
+
kernel_init=None):
|
30 |
+
super().__init__()
|
31 |
+
assert not (up and down)
|
32 |
+
assert kernel >= 1 and kernel % 2 == 1
|
33 |
+
self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
|
34 |
+
if kernel_init is not None:
|
35 |
+
self.weight.data = kernel_init(self.weight.data.shape)
|
36 |
+
if use_bias:
|
37 |
+
self.bias = nn.Parameter(torch.zeros(out_ch))
|
38 |
+
|
39 |
+
self.up = up
|
40 |
+
self.down = down
|
41 |
+
self.resample_kernel = resample_kernel
|
42 |
+
self.kernel = kernel
|
43 |
+
self.use_bias = use_bias
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
if self.up:
|
47 |
+
x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
|
48 |
+
elif self.down:
|
49 |
+
x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
|
50 |
+
else:
|
51 |
+
x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
|
52 |
+
|
53 |
+
if self.use_bias:
|
54 |
+
x = x + self.bias.reshape(1, -1, 1, 1)
|
55 |
+
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
def naive_upsample_2d(x, factor=2):
|
60 |
+
_N, C, H, W = x.shape
|
61 |
+
x = torch.reshape(x, (-1, C, H, 1, W, 1))
|
62 |
+
x = x.repeat(1, 1, 1, factor, 1, factor)
|
63 |
+
return torch.reshape(x, (-1, C, H * factor, W * factor))
|
64 |
+
|
65 |
+
|
66 |
+
def naive_downsample_2d(x, factor=2):
|
67 |
+
_N, C, H, W = x.shape
|
68 |
+
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
|
69 |
+
return torch.mean(x, dim=(3, 5))
|
70 |
+
|
71 |
+
|
72 |
+
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
|
73 |
+
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
|
74 |
+
|
75 |
+
Padding is performed only once at the beginning, not between the
|
76 |
+
operations.
|
77 |
+
The fused op is considerably more efficient than performing the same
|
78 |
+
calculation
|
79 |
+
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
80 |
+
Args:
|
81 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
82 |
+
C]`.
|
83 |
+
w: Weight tensor of the shape `[filterH, filterW, inChannels,
|
84 |
+
outChannels]`. Grouped convolution can be performed by `inChannels =
|
85 |
+
x.shape[0] // numGroups`.
|
86 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
87 |
+
(separable). The default is `[1] * factor`, which corresponds to
|
88 |
+
nearest-neighbor upsampling.
|
89 |
+
factor: Integer upsampling factor (default: 2).
|
90 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
Tensor of the shape `[N, C, H * factor, W * factor]` or
|
94 |
+
`[N, H * factor, W * factor, C]`, and same datatype as `x`.
|
95 |
+
"""
|
96 |
+
|
97 |
+
assert isinstance(factor, int) and factor >= 1
|
98 |
+
|
99 |
+
# Check weight shape.
|
100 |
+
assert len(w.shape) == 4
|
101 |
+
convH = w.shape[2]
|
102 |
+
convW = w.shape[3]
|
103 |
+
inC = w.shape[1]
|
104 |
+
outC = w.shape[0]
|
105 |
+
|
106 |
+
assert convW == convH
|
107 |
+
|
108 |
+
# Setup filter kernel.
|
109 |
+
if k is None:
|
110 |
+
k = [1] * factor
|
111 |
+
k = _setup_kernel(k) * (gain * (factor ** 2))
|
112 |
+
p = (k.shape[0] - factor) - (convW - 1)
|
113 |
+
|
114 |
+
stride = (factor, factor)
|
115 |
+
|
116 |
+
# Determine data dimensions.
|
117 |
+
stride = [1, 1, factor, factor]
|
118 |
+
output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
|
119 |
+
output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
|
120 |
+
output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
|
121 |
+
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
122 |
+
num_groups = _shape(x, 1) // inC
|
123 |
+
|
124 |
+
# Transpose weights.
|
125 |
+
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
|
126 |
+
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
|
127 |
+
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
|
128 |
+
|
129 |
+
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
|
130 |
+
## Original TF code.
|
131 |
+
# x = tf.nn.conv2d_transpose(
|
132 |
+
# x,
|
133 |
+
# w,
|
134 |
+
# output_shape=output_shape,
|
135 |
+
# strides=stride,
|
136 |
+
# padding='VALID',
|
137 |
+
# data_format=data_format)
|
138 |
+
## JAX equivalent
|
139 |
+
|
140 |
+
return upfirdn2d(x, torch.tensor(k, device=x.device),
|
141 |
+
pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
|
142 |
+
|
143 |
+
|
144 |
+
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
|
145 |
+
"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
|
146 |
+
|
147 |
+
Padding is performed only once at the beginning, not between the operations.
|
148 |
+
The fused op is considerably more efficient than performing the same
|
149 |
+
calculation
|
150 |
+
using standard TensorFlow ops. It supports gradients of arbitrary order.
|
151 |
+
Args:
|
152 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
153 |
+
C]`.
|
154 |
+
w: Weight tensor of the shape `[filterH, filterW, inChannels,
|
155 |
+
outChannels]`. Grouped convolution can be performed by `inChannels =
|
156 |
+
x.shape[0] // numGroups`.
|
157 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
158 |
+
(separable). The default is `[1] * factor`, which corresponds to
|
159 |
+
average pooling.
|
160 |
+
factor: Integer downsampling factor (default: 2).
|
161 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
Tensor of the shape `[N, C, H // factor, W // factor]` or
|
165 |
+
`[N, H // factor, W // factor, C]`, and same datatype as `x`.
|
166 |
+
"""
|
167 |
+
|
168 |
+
assert isinstance(factor, int) and factor >= 1
|
169 |
+
_outC, _inC, convH, convW = w.shape
|
170 |
+
assert convW == convH
|
171 |
+
if k is None:
|
172 |
+
k = [1] * factor
|
173 |
+
k = _setup_kernel(k) * gain
|
174 |
+
p = (k.shape[0] - factor) + (convW - 1)
|
175 |
+
s = [factor, factor]
|
176 |
+
x = upfirdn2d(x, torch.tensor(k, device=x.device),
|
177 |
+
pad=((p + 1) // 2, p // 2))
|
178 |
+
return F.conv2d(x, w, stride=s, padding=0)
|
179 |
+
|
180 |
+
|
181 |
+
def _setup_kernel(k):
|
182 |
+
k = np.asarray(k, dtype=np.float32)
|
183 |
+
if k.ndim == 1:
|
184 |
+
k = np.outer(k, k)
|
185 |
+
k /= np.sum(k)
|
186 |
+
assert k.ndim == 2
|
187 |
+
assert k.shape[0] == k.shape[1]
|
188 |
+
return k
|
189 |
+
|
190 |
+
|
191 |
+
def _shape(x, dim):
|
192 |
+
return x.shape[dim]
|
193 |
+
|
194 |
+
|
195 |
+
def upsample_2d(x, k=None, factor=2, gain=1):
|
196 |
+
r"""Upsample a batch of 2D images with the given filter.
|
197 |
+
|
198 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
199 |
+
and upsamples each image with the given filter. The filter is normalized so
|
200 |
+
that
|
201 |
+
if the input pixels are constant, they will be scaled by the specified
|
202 |
+
`gain`.
|
203 |
+
Pixels outside the image are assumed to be zero, and the filter is padded
|
204 |
+
with
|
205 |
+
zeros so that its shape is a multiple of the upsampling factor.
|
206 |
+
Args:
|
207 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
208 |
+
C]`.
|
209 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
210 |
+
(separable). The default is `[1] * factor`, which corresponds to
|
211 |
+
nearest-neighbor upsampling.
|
212 |
+
factor: Integer upsampling factor (default: 2).
|
213 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
Tensor of the shape `[N, C, H * factor, W * factor]`
|
217 |
+
"""
|
218 |
+
assert isinstance(factor, int) and factor >= 1
|
219 |
+
if k is None:
|
220 |
+
k = [1] * factor
|
221 |
+
k = _setup_kernel(k) * (gain * (factor ** 2))
|
222 |
+
p = k.shape[0] - factor
|
223 |
+
return upfirdn2d(x, torch.tensor(k, device=x.device),
|
224 |
+
up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
|
225 |
+
|
226 |
+
|
227 |
+
def downsample_2d(x, k=None, factor=2, gain=1):
|
228 |
+
r"""Downsample a batch of 2D images with the given filter.
|
229 |
+
|
230 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
|
231 |
+
and downsamples each image with the given filter. The filter is normalized
|
232 |
+
so that
|
233 |
+
if the input pixels are constant, they will be scaled by the specified
|
234 |
+
`gain`.
|
235 |
+
Pixels outside the image are assumed to be zero, and the filter is padded
|
236 |
+
with
|
237 |
+
zeros so that its shape is a multiple of the downsampling factor.
|
238 |
+
Args:
|
239 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
240 |
+
C]`.
|
241 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
242 |
+
(separable). The default is `[1] * factor`, which corresponds to
|
243 |
+
average pooling.
|
244 |
+
factor: Integer downsampling factor (default: 2).
|
245 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
Tensor of the shape `[N, C, H // factor, W // factor]`
|
249 |
+
"""
|
250 |
+
|
251 |
+
assert isinstance(factor, int) and factor >= 1
|
252 |
+
if k is None:
|
253 |
+
k = [1] * factor
|
254 |
+
k = _setup_kernel(k) * gain
|
255 |
+
p = k.shape[0] - factor
|
256 |
+
return upfirdn2d(x, torch.tensor(k, device=x.device),
|
257 |
+
down=factor, pad=((p + 1) // 2, p // 2))
|
sgmse/backbones/ncsnpp_utils/utils.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""All functions and modules related to model definition.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
from ...sdes import OUVESDE, OUVPSDE
|
23 |
+
|
24 |
+
|
25 |
+
_MODELS = {}
|
26 |
+
|
27 |
+
|
28 |
+
def register_model(cls=None, *, name=None):
|
29 |
+
"""A decorator for registering model classes."""
|
30 |
+
|
31 |
+
def _register(cls):
|
32 |
+
if name is None:
|
33 |
+
local_name = cls.__name__
|
34 |
+
else:
|
35 |
+
local_name = name
|
36 |
+
if local_name in _MODELS:
|
37 |
+
raise ValueError(f'Already registered model with name: {local_name}')
|
38 |
+
_MODELS[local_name] = cls
|
39 |
+
return cls
|
40 |
+
|
41 |
+
if cls is None:
|
42 |
+
return _register
|
43 |
+
else:
|
44 |
+
return _register(cls)
|
45 |
+
|
46 |
+
|
47 |
+
def get_model(name):
|
48 |
+
return _MODELS[name]
|
49 |
+
|
50 |
+
|
51 |
+
def get_sigmas(sigma_min, sigma_max, num_scales):
|
52 |
+
"""Get sigmas --- the set of noise levels for SMLD from config files.
|
53 |
+
Args:
|
54 |
+
config: A ConfigDict object parsed from the config file
|
55 |
+
Returns:
|
56 |
+
sigmas: a jax numpy arrary of noise levels
|
57 |
+
"""
|
58 |
+
sigmas = np.exp(
|
59 |
+
np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
|
60 |
+
|
61 |
+
return sigmas
|
62 |
+
|
63 |
+
|
64 |
+
def get_ddpm_params(config):
|
65 |
+
"""Get betas and alphas --- parameters used in the original DDPM paper."""
|
66 |
+
num_diffusion_timesteps = 1000
|
67 |
+
# parameters need to be adapted if number of time steps differs from 1000
|
68 |
+
beta_start = config.model.beta_min / config.model.num_scales
|
69 |
+
beta_end = config.model.beta_max / config.model.num_scales
|
70 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
71 |
+
|
72 |
+
alphas = 1. - betas
|
73 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
74 |
+
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
|
75 |
+
sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
|
76 |
+
|
77 |
+
return {
|
78 |
+
'betas': betas,
|
79 |
+
'alphas': alphas,
|
80 |
+
'alphas_cumprod': alphas_cumprod,
|
81 |
+
'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
|
82 |
+
'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
|
83 |
+
'beta_min': beta_start * (num_diffusion_timesteps - 1),
|
84 |
+
'beta_max': beta_end * (num_diffusion_timesteps - 1),
|
85 |
+
'num_diffusion_timesteps': num_diffusion_timesteps
|
86 |
+
}
|
87 |
+
|
88 |
+
|
89 |
+
def create_model(config):
|
90 |
+
"""Create the score model."""
|
91 |
+
model_name = config.model.name
|
92 |
+
score_model = get_model(model_name)(config)
|
93 |
+
score_model = score_model.to(config.device)
|
94 |
+
score_model = torch.nn.DataParallel(score_model)
|
95 |
+
return score_model
|
96 |
+
|
97 |
+
|
98 |
+
def get_model_fn(model, train=False):
|
99 |
+
"""Create a function to give the output of the score-based model.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
model: The score model.
|
103 |
+
train: `True` for training and `False` for evaluation.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
A model function.
|
107 |
+
"""
|
108 |
+
|
109 |
+
def model_fn(x, labels):
|
110 |
+
"""Compute the output of the score-based model.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
x: A mini-batch of input data.
|
114 |
+
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
|
115 |
+
for different models.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
A tuple of (model output, new mutable states)
|
119 |
+
"""
|
120 |
+
if not train:
|
121 |
+
model.eval()
|
122 |
+
return model(x, labels)
|
123 |
+
else:
|
124 |
+
model.train()
|
125 |
+
return model(x, labels)
|
126 |
+
|
127 |
+
return model_fn
|
128 |
+
|
129 |
+
|
130 |
+
def get_score_fn(sde, model, train=False, continuous=False):
|
131 |
+
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
sde: An `sde_lib.SDE` object that represents the forward SDE.
|
135 |
+
model: A score model.
|
136 |
+
train: `True` for training and `False` for evaluation.
|
137 |
+
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
A score function.
|
141 |
+
"""
|
142 |
+
model_fn = get_model_fn(model, train=train)
|
143 |
+
|
144 |
+
if isinstance(sde, OUVPSDE):
|
145 |
+
def score_fn(x, t):
|
146 |
+
# Scale neural network output by standard deviation and flip sign
|
147 |
+
if continuous:
|
148 |
+
# For VP-trained models, t=0 corresponds to the lowest noise level
|
149 |
+
# The maximum value of time embedding is assumed to 999 for
|
150 |
+
# continuously-trained models.
|
151 |
+
labels = t * 999
|
152 |
+
score = model_fn(x, labels)
|
153 |
+
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
154 |
+
else:
|
155 |
+
# For VP-trained models, t=0 corresponds to the lowest noise level
|
156 |
+
labels = t * (sde.N - 1)
|
157 |
+
score = model_fn(x, labels)
|
158 |
+
std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
|
159 |
+
|
160 |
+
score = -score / std[:, None, None, None]
|
161 |
+
return score
|
162 |
+
|
163 |
+
elif isinstance(sde, OUVESDE):
|
164 |
+
def score_fn(x, t):
|
165 |
+
if continuous:
|
166 |
+
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
|
167 |
+
else:
|
168 |
+
# For VE-trained models, t=0 corresponds to the highest noise level
|
169 |
+
labels = sde.T - t
|
170 |
+
labels *= sde.N - 1
|
171 |
+
labels = torch.round(labels).long()
|
172 |
+
|
173 |
+
score = model_fn(x, labels)
|
174 |
+
return score
|
175 |
+
|
176 |
+
else:
|
177 |
+
raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
|
178 |
+
|
179 |
+
return score_fn
|
180 |
+
|
181 |
+
|
182 |
+
def to_flattened_numpy(x):
|
183 |
+
"""Flatten a torch tensor `x` and convert it to numpy."""
|
184 |
+
return x.detach().cpu().numpy().reshape((-1,))
|
185 |
+
|
186 |
+
|
187 |
+
def from_flattened_numpy(x, shape):
|
188 |
+
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
|
189 |
+
return torch.from_numpy(x.reshape(shape))
|
sgmse/backbones/shared.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from sgmse.util.registry import Registry
|
8 |
+
|
9 |
+
|
10 |
+
BackboneRegistry = Registry("Backbone")
|
11 |
+
|
12 |
+
|
13 |
+
class GaussianFourierProjection(nn.Module):
|
14 |
+
"""Gaussian random features for encoding time steps."""
|
15 |
+
|
16 |
+
def __init__(self, embed_dim, scale=16, complex_valued=False):
|
17 |
+
super().__init__()
|
18 |
+
self.complex_valued = complex_valued
|
19 |
+
if not complex_valued:
|
20 |
+
# If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
|
21 |
+
# Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
|
22 |
+
# we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
|
23 |
+
# and this halving is not necessary.
|
24 |
+
embed_dim = embed_dim // 2
|
25 |
+
# Randomly sample weights during initialization. These weights are fixed
|
26 |
+
# during optimization and are not trainable.
|
27 |
+
self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)
|
28 |
+
|
29 |
+
def forward(self, t):
|
30 |
+
t_proj = t[:, None] * self.W[None, :] * 2*np.pi
|
31 |
+
if self.complex_valued:
|
32 |
+
return torch.exp(1j * t_proj)
|
33 |
+
else:
|
34 |
+
return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
|
35 |
+
|
36 |
+
|
37 |
+
class DiffusionStepEmbedding(nn.Module):
|
38 |
+
"""Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017."""
|
39 |
+
|
40 |
+
def __init__(self, embed_dim, complex_valued=False):
|
41 |
+
super().__init__()
|
42 |
+
self.complex_valued = complex_valued
|
43 |
+
if not complex_valued:
|
44 |
+
# If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
|
45 |
+
# Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
|
46 |
+
# we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
|
47 |
+
# and this halving is not necessary.
|
48 |
+
embed_dim = embed_dim // 2
|
49 |
+
self.embed_dim = embed_dim
|
50 |
+
|
51 |
+
def forward(self, t):
|
52 |
+
fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / (self.embed_dim-1))
|
53 |
+
inner = t[:, None] * fac[None, :]
|
54 |
+
if self.complex_valued:
|
55 |
+
return torch.exp(1j * inner)
|
56 |
+
else:
|
57 |
+
return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
|
58 |
+
|
59 |
+
|
60 |
+
class ComplexLinear(nn.Module):
|
61 |
+
"""A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""
|
62 |
+
def __init__(self, input_dim, output_dim, complex_valued):
|
63 |
+
super().__init__()
|
64 |
+
self.complex_valued = complex_valued
|
65 |
+
if self.complex_valued:
|
66 |
+
self.re = nn.Linear(input_dim, output_dim)
|
67 |
+
self.im = nn.Linear(input_dim, output_dim)
|
68 |
+
else:
|
69 |
+
self.lin = nn.Linear(input_dim, output_dim)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
if self.complex_valued:
|
73 |
+
return (self.re(x.real) - self.im(x.imag)) + 1j*(self.re(x.imag) + self.im(x.real))
|
74 |
+
else:
|
75 |
+
return self.lin(x)
|
76 |
+
|
77 |
+
|
78 |
+
class FeatureMapDense(nn.Module):
|
79 |
+
"""A fully connected layer that reshapes outputs to feature maps."""
|
80 |
+
|
81 |
+
def __init__(self, input_dim, output_dim, complex_valued=False):
|
82 |
+
super().__init__()
|
83 |
+
self.complex_valued = complex_valued
|
84 |
+
self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
return self.dense(x)[..., None, None]
|
88 |
+
|
89 |
+
|
90 |
+
def torch_complex_from_reim(re, im):
|
91 |
+
return torch.view_as_complex(torch.stack([re, im], dim=-1))
|
92 |
+
|
93 |
+
|
94 |
+
class ArgsComplexMultiplicationWrapper(nn.Module):
|
95 |
+
"""Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().
|
96 |
+
|
97 |
+
Make a complex-valued module `F` from a real-valued module `f` by applying
|
98 |
+
complex multiplication rules:
|
99 |
+
|
100 |
+
F(a + i b) = f1(a) - f1(b) + i (f2(b) + f2(a))
|
101 |
+
|
102 |
+
where `f1`, `f2` are instances of `f` that do *not* share weights.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
module_cls (callable): A class or function that returns a Torch module/functional.
|
106 |
+
Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
|
107 |
+
to construct the real and imaginary component modules.
|
108 |
+
"""
|
109 |
+
|
110 |
+
def __init__(self, module_cls, *args, **kwargs):
|
111 |
+
super().__init__()
|
112 |
+
self.re_module = module_cls(*args, **kwargs)
|
113 |
+
self.im_module = module_cls(*args, **kwargs)
|
114 |
+
|
115 |
+
def forward(self, x, *args, **kwargs):
|
116 |
+
return torch_complex_from_reim(
|
117 |
+
self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs),
|
118 |
+
self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs),
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
|
123 |
+
ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
|
sgmse/data_module.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from os.path import join
|
3 |
+
import torch
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from glob import glob
|
8 |
+
from torchaudio import load
|
9 |
+
import numpy as np
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
def get_window(window_type, window_length):
|
14 |
+
if window_type == 'sqrthann':
|
15 |
+
return torch.sqrt(torch.hann_window(window_length, periodic=True))
|
16 |
+
elif window_type == 'hann':
|
17 |
+
return torch.hann_window(window_length, periodic=True)
|
18 |
+
else:
|
19 |
+
raise NotImplementedError(f"Window type {window_type} not implemented!")
|
20 |
+
|
21 |
+
|
22 |
+
class Specs(Dataset):
|
23 |
+
def __init__(self, data_dir, subset, dummy, shuffle_spec, num_frames,
|
24 |
+
format='default', normalize="noisy", spec_transform=None,
|
25 |
+
stft_kwargs=None, **ignored_kwargs):
|
26 |
+
|
27 |
+
# Read file paths according to file naming format.
|
28 |
+
if format == "default":
|
29 |
+
self.clean_files = []
|
30 |
+
self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
|
31 |
+
self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
|
32 |
+
self.noisy_files = []
|
33 |
+
self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
|
34 |
+
self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
|
35 |
+
elif format == "reverb":
|
36 |
+
self.clean_files = []
|
37 |
+
self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
|
38 |
+
self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
|
39 |
+
self.noisy_files = []
|
40 |
+
self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
|
41 |
+
self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
|
42 |
+
else:
|
43 |
+
# Feel free to add your own directory format
|
44 |
+
raise NotImplementedError(f"Directory format {format} unknown!")
|
45 |
+
|
46 |
+
self.dummy = dummy
|
47 |
+
self.num_frames = num_frames
|
48 |
+
self.shuffle_spec = shuffle_spec
|
49 |
+
self.normalize = normalize
|
50 |
+
self.spec_transform = spec_transform
|
51 |
+
|
52 |
+
assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
|
53 |
+
self.stft_kwargs = stft_kwargs
|
54 |
+
self.hop_length = self.stft_kwargs["hop_length"]
|
55 |
+
assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"
|
56 |
+
|
57 |
+
def __getitem__(self, i):
|
58 |
+
x, _ = load(self.clean_files[i])
|
59 |
+
y, _ = load(self.noisy_files[i])
|
60 |
+
|
61 |
+
# formula applies for center=True
|
62 |
+
target_len = (self.num_frames - 1) * self.hop_length
|
63 |
+
current_len = x.size(-1)
|
64 |
+
pad = max(target_len - current_len, 0)
|
65 |
+
if pad == 0:
|
66 |
+
# extract random part of the audio file
|
67 |
+
if self.shuffle_spec:
|
68 |
+
start = int(np.random.uniform(0, current_len-target_len))
|
69 |
+
else:
|
70 |
+
start = int((current_len-target_len)/2)
|
71 |
+
x = x[..., start:start+target_len]
|
72 |
+
y = y[..., start:start+target_len]
|
73 |
+
else:
|
74 |
+
# pad audio if the length T is smaller than num_frames
|
75 |
+
x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
|
76 |
+
y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')
|
77 |
+
|
78 |
+
# normalize w.r.t to the noisy or the clean signal or not at all
|
79 |
+
# to ensure same clean signal power in x and y.
|
80 |
+
if self.normalize == "noisy":
|
81 |
+
normfac = y.abs().max()
|
82 |
+
elif self.normalize == "clean":
|
83 |
+
normfac = x.abs().max()
|
84 |
+
elif self.normalize == "not":
|
85 |
+
normfac = 1.0
|
86 |
+
x = x / normfac
|
87 |
+
y = y / normfac
|
88 |
+
|
89 |
+
X = torch.stft(x, **self.stft_kwargs)
|
90 |
+
Y = torch.stft(y, **self.stft_kwargs)
|
91 |
+
|
92 |
+
X, Y = self.spec_transform(X), self.spec_transform(Y)
|
93 |
+
return X, Y
|
94 |
+
|
95 |
+
def __len__(self):
|
96 |
+
if self.dummy:
|
97 |
+
# for debugging shrink the data set size
|
98 |
+
return int(len(self.clean_files)/200)
|
99 |
+
else:
|
100 |
+
return len(self.clean_files)
|
101 |
+
|
102 |
+
|
103 |
+
class SpecsDataModule(pl.LightningDataModule):
|
104 |
+
@staticmethod
|
105 |
+
def add_argparse_args(parser):
|
106 |
+
parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.")
|
107 |
+
parser.add_argument("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.")
|
108 |
+
parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
|
109 |
+
parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.") # to assure 256 freq bins
|
110 |
+
parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
|
111 |
+
parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
|
112 |
+
parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
|
113 |
+
parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to use for DataLoaders. 4 by default.")
|
114 |
+
parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
|
115 |
+
parser.add_argument("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. 0.15 by default.")
|
116 |
+
parser.add_argument("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). 0.5 by default.")
|
117 |
+
parser.add_argument("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
|
118 |
+
parser.add_argument("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectogram transformation for input representation.")
|
119 |
+
return parser
|
120 |
+
|
121 |
+
def __init__(
|
122 |
+
self, base_dir, format='default', batch_size=8,
|
123 |
+
n_fft=510, hop_length=128, num_frames=256, window='hann',
|
124 |
+
num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
|
125 |
+
gpu=True, normalize='noisy', transform_type="exponent", **kwargs
|
126 |
+
):
|
127 |
+
super().__init__()
|
128 |
+
self.base_dir = base_dir
|
129 |
+
self.format = format
|
130 |
+
self.batch_size = batch_size
|
131 |
+
self.n_fft = n_fft
|
132 |
+
self.hop_length = hop_length
|
133 |
+
self.num_frames = num_frames
|
134 |
+
self.window = get_window(window, self.n_fft)
|
135 |
+
self.windows = {}
|
136 |
+
self.num_workers = num_workers
|
137 |
+
self.dummy = dummy
|
138 |
+
self.spec_factor = spec_factor
|
139 |
+
self.spec_abs_exponent = spec_abs_exponent
|
140 |
+
self.gpu = gpu
|
141 |
+
self.normalize = normalize
|
142 |
+
self.transform_type = transform_type
|
143 |
+
self.kwargs = kwargs
|
144 |
+
|
145 |
+
def setup(self, stage=None):
|
146 |
+
specs_kwargs = dict(
|
147 |
+
stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
|
148 |
+
spec_transform=self.spec_fwd, **self.kwargs
|
149 |
+
)
|
150 |
+
if stage == 'fit' or stage is None:
|
151 |
+
self.train_set = Specs(data_dir=self.base_dir, subset='train',
|
152 |
+
dummy=self.dummy, shuffle_spec=True, format=self.format,
|
153 |
+
normalize=self.normalize, **specs_kwargs)
|
154 |
+
self.valid_set = Specs(data_dir=self.base_dir, subset='valid',
|
155 |
+
dummy=self.dummy, shuffle_spec=False, format=self.format,
|
156 |
+
normalize=self.normalize, **specs_kwargs)
|
157 |
+
if stage == 'test' or stage is None:
|
158 |
+
self.test_set = Specs(data_dir=self.base_dir, subset='test',
|
159 |
+
dummy=self.dummy, shuffle_spec=False, format=self.format,
|
160 |
+
normalize=self.normalize, **specs_kwargs)
|
161 |
+
|
162 |
+
def spec_fwd(self, spec):
|
163 |
+
if self.transform_type == "exponent":
|
164 |
+
if self.spec_abs_exponent != 1:
|
165 |
+
# only do this calculation if spec_exponent != 1, otherwise it's quite a bit of wasted computation
|
166 |
+
# and introduced numerical error
|
167 |
+
e = self.spec_abs_exponent
|
168 |
+
spec = spec.abs()**e * torch.exp(1j * spec.angle())
|
169 |
+
spec = spec * self.spec_factor
|
170 |
+
elif self.transform_type == "log":
|
171 |
+
spec = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
|
172 |
+
spec = spec * self.spec_factor
|
173 |
+
elif self.transform_type == "none":
|
174 |
+
spec = spec
|
175 |
+
return spec
|
176 |
+
|
177 |
+
def spec_back(self, spec):
|
178 |
+
if self.transform_type == "exponent":
|
179 |
+
spec = spec / self.spec_factor
|
180 |
+
if self.spec_abs_exponent != 1:
|
181 |
+
e = self.spec_abs_exponent
|
182 |
+
spec = spec.abs()**(1/e) * torch.exp(1j * spec.angle())
|
183 |
+
elif self.transform_type == "log":
|
184 |
+
spec = spec / self.spec_factor
|
185 |
+
spec = (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
|
186 |
+
elif self.transform_type == "none":
|
187 |
+
spec = spec
|
188 |
+
return spec
|
189 |
+
|
190 |
+
@property
|
191 |
+
def stft_kwargs(self):
|
192 |
+
return {**self.istft_kwargs, "return_complex": True}
|
193 |
+
|
194 |
+
@property
|
195 |
+
def istft_kwargs(self):
|
196 |
+
return dict(
|
197 |
+
n_fft=self.n_fft, hop_length=self.hop_length,
|
198 |
+
window=self.window, center=True
|
199 |
+
)
|
200 |
+
|
201 |
+
def _get_window(self, x):
|
202 |
+
"""
|
203 |
+
Retrieve an appropriate window for the given tensor x, matching the device.
|
204 |
+
Caches the retrieved windows so that only one window tensor will be allocated per device.
|
205 |
+
"""
|
206 |
+
window = self.windows.get(x.device, None)
|
207 |
+
if window is None:
|
208 |
+
window = self.window.to(x.device)
|
209 |
+
self.windows[x.device] = window
|
210 |
+
return window
|
211 |
+
|
212 |
+
def stft(self, sig):
|
213 |
+
window = self._get_window(sig)
|
214 |
+
return torch.stft(sig, **{**self.stft_kwargs, "window": window})
|
215 |
+
|
216 |
+
def istft(self, spec, length=None):
|
217 |
+
window = self._get_window(spec)
|
218 |
+
return torch.istft(spec, **{**self.istft_kwargs, "window": window, "length": length})
|
219 |
+
|
220 |
+
def train_dataloader(self):
|
221 |
+
return DataLoader(
|
222 |
+
self.train_set, batch_size=self.batch_size,
|
223 |
+
num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
|
224 |
+
)
|
225 |
+
|
226 |
+
def val_dataloader(self):
|
227 |
+
return DataLoader(
|
228 |
+
self.valid_set, batch_size=self.batch_size,
|
229 |
+
num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
|
230 |
+
)
|
231 |
+
|
232 |
+
def test_dataloader(self):
|
233 |
+
return DataLoader(
|
234 |
+
self.test_set, batch_size=self.batch_size,
|
235 |
+
num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
|
236 |
+
)
|
sgmse/model.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from math import ceil
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
from torch_ema import ExponentialMovingAverage
|
8 |
+
|
9 |
+
from sgmse import sampling
|
10 |
+
from sgmse.sdes import SDERegistry
|
11 |
+
from sgmse.backbones import BackboneRegistry
|
12 |
+
from sgmse.util.inference import evaluate_model
|
13 |
+
from sgmse.util.other import pad_spec
|
14 |
+
|
15 |
+
|
16 |
+
class ScoreModel(pl.LightningModule):
|
17 |
+
@staticmethod
|
18 |
+
def add_argparse_args(parser):
|
19 |
+
parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
|
20 |
+
parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
|
21 |
+
parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
|
22 |
+
parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
|
23 |
+
parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
|
24 |
+
return parser
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
|
28 |
+
num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
Create a new ScoreModel.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
backbone: Backbone DNN that serves as a score-based model.
|
35 |
+
sde: The SDE that defines the diffusion process.
|
36 |
+
lr: The learning rate of the optimizer. (1e-4 by default).
|
37 |
+
ema_decay: The decay constant of the parameter EMA (0.999 by default).
|
38 |
+
t_eps: The minimum time to practically run for to avoid issues very close to zero (1e-5 by default).
|
39 |
+
loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'
|
40 |
+
"""
|
41 |
+
super().__init__()
|
42 |
+
# Initialize Backbone DNN
|
43 |
+
self.backbone = backbone
|
44 |
+
dnn_cls = BackboneRegistry.get_by_name(backbone)
|
45 |
+
self.dnn = dnn_cls(**kwargs)
|
46 |
+
# Initialize SDE
|
47 |
+
sde_cls = SDERegistry.get_by_name(sde)
|
48 |
+
self.sde = sde_cls(**kwargs)
|
49 |
+
# Store hyperparams and save them
|
50 |
+
self.lr = lr
|
51 |
+
self.ema_decay = ema_decay
|
52 |
+
self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
|
53 |
+
self._error_loading_ema = False
|
54 |
+
self.t_eps = t_eps
|
55 |
+
self.loss_type = loss_type
|
56 |
+
self.num_eval_files = num_eval_files
|
57 |
+
|
58 |
+
self.save_hyperparameters(ignore=['no_wandb'])
|
59 |
+
self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
|
60 |
+
|
61 |
+
def configure_optimizers(self):
|
62 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
63 |
+
return optimizer
|
64 |
+
|
65 |
+
def optimizer_step(self, *args, **kwargs):
|
66 |
+
# Method overridden so that the EMA params are updated after each optimizer step
|
67 |
+
super().optimizer_step(*args, **kwargs)
|
68 |
+
self.ema.update(self.parameters())
|
69 |
+
|
70 |
+
# on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
|
71 |
+
def on_load_checkpoint(self, checkpoint):
|
72 |
+
ema = checkpoint.get('ema', None)
|
73 |
+
if ema is not None:
|
74 |
+
self.ema.load_state_dict(checkpoint['ema'])
|
75 |
+
else:
|
76 |
+
self._error_loading_ema = True
|
77 |
+
warnings.warn("EMA state_dict not found in checkpoint!")
|
78 |
+
|
79 |
+
def on_save_checkpoint(self, checkpoint):
|
80 |
+
checkpoint['ema'] = self.ema.state_dict()
|
81 |
+
|
82 |
+
def train(self, mode, no_ema=False):
|
83 |
+
res = super().train(mode) # call the standard `train` method with the given mode
|
84 |
+
if not self._error_loading_ema:
|
85 |
+
if mode == False and not no_ema:
|
86 |
+
# eval
|
87 |
+
self.ema.store(self.parameters()) # store current params in EMA
|
88 |
+
self.ema.copy_to(self.parameters()) # copy EMA parameters over current params for evaluation
|
89 |
+
else:
|
90 |
+
# train
|
91 |
+
if self.ema.collected_params is not None:
|
92 |
+
self.ema.restore(self.parameters()) # restore the EMA weights (if stored)
|
93 |
+
return res
|
94 |
+
|
95 |
+
def eval(self, no_ema=False):
|
96 |
+
return self.train(False, no_ema=no_ema)
|
97 |
+
|
98 |
+
def _loss(self, err):
|
99 |
+
if self.loss_type == 'mse':
|
100 |
+
losses = torch.square(err.abs())
|
101 |
+
elif self.loss_type == 'mae':
|
102 |
+
losses = err.abs()
|
103 |
+
# taken from reduce_op function: sum over channels and position and mean over batch dim
|
104 |
+
# presumably only important for absolute loss number, not for gradients
|
105 |
+
loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
|
106 |
+
return loss
|
107 |
+
|
108 |
+
def _step(self, batch, batch_idx):
|
109 |
+
x, y = batch
|
110 |
+
t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
|
111 |
+
mean, std = self.sde.marginal_prob(x, t, y)
|
112 |
+
z = torch.randn_like(x) # i.i.d. normal distributed with var=0.5
|
113 |
+
sigmas = std[:, None, None, None]
|
114 |
+
perturbed_data = mean + sigmas * z
|
115 |
+
score = self(perturbed_data, t, y)
|
116 |
+
err = score * sigmas + z
|
117 |
+
loss = self._loss(err)
|
118 |
+
return loss
|
119 |
+
|
120 |
+
def training_step(self, batch, batch_idx):
|
121 |
+
loss = self._step(batch, batch_idx)
|
122 |
+
self.log('train_loss', loss, on_step=True, on_epoch=True)
|
123 |
+
return loss
|
124 |
+
|
125 |
+
def validation_step(self, batch, batch_idx):
|
126 |
+
loss = self._step(batch, batch_idx)
|
127 |
+
self.log('valid_loss', loss, on_step=False, on_epoch=True)
|
128 |
+
|
129 |
+
# Evaluate speech enhancement performance
|
130 |
+
if batch_idx == 0 and self.num_eval_files != 0:
|
131 |
+
pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
|
132 |
+
self.log('pesq', pesq, on_step=False, on_epoch=True)
|
133 |
+
self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
|
134 |
+
self.log('estoi', estoi, on_step=False, on_epoch=True)
|
135 |
+
|
136 |
+
return loss
|
137 |
+
|
138 |
+
def forward(self, x, t, y):
|
139 |
+
# Concatenate y as an extra channel
|
140 |
+
dnn_input = torch.cat([x, y], dim=1)
|
141 |
+
|
142 |
+
# the minus is most likely unimportant here - taken from Song's repo
|
143 |
+
score = -self.dnn(dnn_input, t)
|
144 |
+
return score
|
145 |
+
|
146 |
+
def to(self, *args, **kwargs):
|
147 |
+
"""Override PyTorch .to() to also transfer the EMA of the model weights"""
|
148 |
+
self.ema.to(*args, **kwargs)
|
149 |
+
return super().to(*args, **kwargs)
|
150 |
+
|
151 |
+
def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
|
152 |
+
N = self.sde.N if N is None else N
|
153 |
+
sde = self.sde.copy()
|
154 |
+
sde.N = N
|
155 |
+
|
156 |
+
kwargs = {"eps": self.t_eps, **kwargs}
|
157 |
+
if minibatch is None:
|
158 |
+
return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
|
159 |
+
else:
|
160 |
+
M = y.shape[0]
|
161 |
+
def batched_sampling_fn():
|
162 |
+
samples, ns = [], []
|
163 |
+
for i in range(int(ceil(M / minibatch))):
|
164 |
+
y_mini = y[i*minibatch:(i+1)*minibatch]
|
165 |
+
sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
|
166 |
+
sample, n = sampler()
|
167 |
+
samples.append(sample)
|
168 |
+
ns.append(n)
|
169 |
+
samples = torch.cat(samples, dim=0)
|
170 |
+
return samples, ns
|
171 |
+
return batched_sampling_fn
|
172 |
+
|
173 |
+
def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
|
174 |
+
N = self.sde.N if N is None else N
|
175 |
+
sde = self.sde.copy()
|
176 |
+
sde.N = N
|
177 |
+
|
178 |
+
kwargs = {"eps": self.t_eps, **kwargs}
|
179 |
+
if minibatch is None:
|
180 |
+
return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
|
181 |
+
else:
|
182 |
+
M = y.shape[0]
|
183 |
+
def batched_sampling_fn():
|
184 |
+
samples, ns = [], []
|
185 |
+
for i in range(int(ceil(M / minibatch))):
|
186 |
+
y_mini = y[i*minibatch:(i+1)*minibatch]
|
187 |
+
sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
|
188 |
+
sample, n = sampler()
|
189 |
+
samples.append(sample)
|
190 |
+
ns.append(n)
|
191 |
+
samples = torch.cat(samples, dim=0)
|
192 |
+
return sample, ns
|
193 |
+
return batched_sampling_fn
|
194 |
+
|
195 |
+
def train_dataloader(self):
|
196 |
+
return self.data_module.train_dataloader()
|
197 |
+
|
198 |
+
def val_dataloader(self):
|
199 |
+
return self.data_module.val_dataloader()
|
200 |
+
|
201 |
+
def test_dataloader(self):
|
202 |
+
return self.data_module.test_dataloader()
|
203 |
+
|
204 |
+
def setup(self, stage=None):
|
205 |
+
return self.data_module.setup(stage=stage)
|
206 |
+
|
207 |
+
def to_audio(self, spec, length=None):
|
208 |
+
return self._istft(self._backward_transform(spec), length)
|
209 |
+
|
210 |
+
def _forward_transform(self, spec):
|
211 |
+
return self.data_module.spec_fwd(spec)
|
212 |
+
|
213 |
+
def _backward_transform(self, spec):
|
214 |
+
return self.data_module.spec_back(spec)
|
215 |
+
|
216 |
+
def _stft(self, sig):
|
217 |
+
return self.data_module.stft(sig)
|
218 |
+
|
219 |
+
def _istft(self, spec, length=None):
|
220 |
+
return self.data_module.istft(spec, length)
|
221 |
+
|
222 |
+
def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
|
223 |
+
corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
|
224 |
+
**kwargs
|
225 |
+
):
|
226 |
+
"""
|
227 |
+
One-call speech enhancement of noisy speech `y`, for convenience.
|
228 |
+
"""
|
229 |
+
sr=16000
|
230 |
+
start = time.time()
|
231 |
+
T_orig = y.size(1)
|
232 |
+
norm_factor = y.abs().max().item()
|
233 |
+
y = y / norm_factor
|
234 |
+
Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
|
235 |
+
Y = pad_spec(Y)
|
236 |
+
if sampler_type == "pc":
|
237 |
+
sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
|
238 |
+
corrector_steps=corrector_steps, snr=snr, intermediate=False,
|
239 |
+
**kwargs)
|
240 |
+
elif sampler_type == "ode":
|
241 |
+
sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
|
242 |
+
else:
|
243 |
+
print("{} is not a valid sampler type!".format(sampler_type))
|
244 |
+
sample, nfe = sampler()
|
245 |
+
x_hat = self.to_audio(sample.squeeze(), T_orig)
|
246 |
+
x_hat = x_hat * norm_factor
|
247 |
+
x_hat = x_hat.squeeze().cpu().numpy()
|
248 |
+
end = time.time()
|
249 |
+
if timeit:
|
250 |
+
rtf = (end-start)/(len(x_hat)/sr)
|
251 |
+
return x_hat, nfe, rtf
|
252 |
+
else:
|
253 |
+
return x_hat
|
sgmse/sampling/__init__.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
|
2 |
+
"""Various sampling methods."""
|
3 |
+
from scipy import integrate
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
|
7 |
+
from .correctors import Corrector, CorrectorRegistry
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
|
12 |
+
'get_sampler'
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
def to_flattened_numpy(x):
|
17 |
+
"""Flatten a torch tensor `x` and convert it to numpy."""
|
18 |
+
return x.detach().cpu().numpy().reshape((-1,))
|
19 |
+
|
20 |
+
|
21 |
+
def from_flattened_numpy(x, shape):
|
22 |
+
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
|
23 |
+
return torch.from_numpy(x.reshape(shape))
|
24 |
+
|
25 |
+
|
26 |
+
def get_pc_sampler(
|
27 |
+
predictor_name, corrector_name, sde, score_fn, y,
|
28 |
+
denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
|
29 |
+
intermediate=False, **kwargs
|
30 |
+
):
|
31 |
+
"""Create a Predictor-Corrector (PC) sampler.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
predictor_name: The name of a registered `sampling.Predictor`.
|
35 |
+
corrector_name: The name of a registered `sampling.Corrector`.
|
36 |
+
sde: An `sdes.SDE` object representing the forward SDE.
|
37 |
+
score_fn: A function (typically learned model) that predicts the score.
|
38 |
+
y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
|
39 |
+
denoise: If `True`, add one-step denoising to the final samples.
|
40 |
+
eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
|
41 |
+
snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
|
42 |
+
N: The number of reverse sampling steps. If `None`, uses the SDE's `N` property by default.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
A sampling function that returns samples and the number of function evaluations during sampling.
|
46 |
+
"""
|
47 |
+
predictor_cls = PredictorRegistry.get_by_name(predictor_name)
|
48 |
+
corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
|
49 |
+
predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
|
50 |
+
corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)
|
51 |
+
|
52 |
+
def pc_sampler():
|
53 |
+
"""The PC sampler function."""
|
54 |
+
with torch.no_grad():
|
55 |
+
xt = sde.prior_sampling(y.shape, y).to(y.device)
|
56 |
+
timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
|
57 |
+
for i in range(sde.N):
|
58 |
+
t = timesteps[i]
|
59 |
+
if i != len(timesteps) - 1:
|
60 |
+
stepsize = t - timesteps[i+1]
|
61 |
+
else:
|
62 |
+
stepsize = timesteps[-1] # from eps to 0
|
63 |
+
vec_t = torch.ones(y.shape[0], device=y.device) * t
|
64 |
+
xt, xt_mean = corrector.update_fn(xt, vec_t, y)
|
65 |
+
xt, xt_mean = predictor.update_fn(xt, vec_t, y, stepsize)
|
66 |
+
x_result = xt_mean if denoise else xt
|
67 |
+
ns = sde.N * (corrector.n_steps + 1)
|
68 |
+
return x_result, ns
|
69 |
+
|
70 |
+
return pc_sampler
|
71 |
+
|
72 |
+
|
73 |
+
def get_ode_sampler(
|
74 |
+
sde, score_fn, y, inverse_scaler=None,
|
75 |
+
denoise=True, rtol=1e-5, atol=1e-5,
|
76 |
+
method='RK45', eps=3e-2, device='cuda', **kwargs
|
77 |
+
):
|
78 |
+
"""Probability flow ODE sampler with the black-box ODE solver.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
sde: An `sdes.SDE` object representing the forward SDE.
|
82 |
+
score_fn: A function (typically learned model) that predicts the score.
|
83 |
+
y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
|
84 |
+
inverse_scaler: The inverse data normalizer.
|
85 |
+
denoise: If `True`, add one-step denoising to final samples.
|
86 |
+
rtol: A `float` number. The relative tolerance level of the ODE solver.
|
87 |
+
atol: A `float` number. The absolute tolerance level of the ODE solver.
|
88 |
+
method: A `str`. The algorithm used for the black-box ODE solver.
|
89 |
+
See the documentation of `scipy.integrate.solve_ivp`.
|
90 |
+
eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
|
91 |
+
device: PyTorch device.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
A sampling function that returns samples and the number of function evaluations during sampling.
|
95 |
+
"""
|
96 |
+
predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
|
97 |
+
rsde = sde.reverse(score_fn, probability_flow=True)
|
98 |
+
|
99 |
+
def denoise_update_fn(x):
|
100 |
+
vec_eps = torch.ones(x.shape[0], device=x.device) * eps
|
101 |
+
_, x = predictor.update_fn(x, vec_eps, y)
|
102 |
+
return x
|
103 |
+
|
104 |
+
def drift_fn(x, t, y):
|
105 |
+
"""Get the drift function of the reverse-time SDE."""
|
106 |
+
return rsde.sde(x, t, y)[0]
|
107 |
+
|
108 |
+
def ode_sampler(z=None, **kwargs):
|
109 |
+
"""The probability flow ODE sampler with black-box ODE solver.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
model: A score model.
|
113 |
+
z: If present, generate samples from latent code `z`.
|
114 |
+
Returns:
|
115 |
+
samples, number of function evaluations.
|
116 |
+
"""
|
117 |
+
with torch.no_grad():
|
118 |
+
# If not represent, sample the latent code from the prior distibution of the SDE.
|
119 |
+
x = sde.prior_sampling(y.shape, y).to(device)
|
120 |
+
|
121 |
+
def ode_func(t, x):
|
122 |
+
x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
|
123 |
+
vec_t = torch.ones(y.shape[0], device=x.device) * t
|
124 |
+
drift = drift_fn(x, vec_t, y)
|
125 |
+
return to_flattened_numpy(drift)
|
126 |
+
|
127 |
+
# Black-box ODE solver for the probability flow ODE
|
128 |
+
solution = integrate.solve_ivp(
|
129 |
+
ode_func, (sde.T, eps), to_flattened_numpy(x),
|
130 |
+
rtol=rtol, atol=atol, method=method, **kwargs
|
131 |
+
)
|
132 |
+
nfe = solution.nfev
|
133 |
+
x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)
|
134 |
+
|
135 |
+
# Denoising is equivalent to running one predictor step without adding noise
|
136 |
+
if denoise:
|
137 |
+
x = denoise_update_fn(x)
|
138 |
+
|
139 |
+
if inverse_scaler is not None:
|
140 |
+
x = inverse_scaler(x)
|
141 |
+
return x, nfe
|
142 |
+
|
143 |
+
return ode_sampler
|
sgmse/sampling/correctors.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from sgmse import sdes
|
5 |
+
from sgmse.util.registry import Registry
|
6 |
+
|
7 |
+
|
8 |
+
CorrectorRegistry = Registry("Corrector")
|
9 |
+
|
10 |
+
|
11 |
+
class Corrector(abc.ABC):
|
12 |
+
"""The abstract class for a corrector algorithm."""
|
13 |
+
|
14 |
+
def __init__(self, sde, score_fn, snr, n_steps):
|
15 |
+
super().__init__()
|
16 |
+
self.rsde = sde.reverse(score_fn)
|
17 |
+
self.score_fn = score_fn
|
18 |
+
self.snr = snr
|
19 |
+
self.n_steps = n_steps
|
20 |
+
|
21 |
+
@abc.abstractmethod
|
22 |
+
def update_fn(self, x, t, *args):
|
23 |
+
"""One update of the corrector.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
x: A PyTorch tensor representing the current state
|
27 |
+
t: A PyTorch tensor representing the current time step.
|
28 |
+
*args: Possibly additional arguments, in particular `y` for OU processes
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
x: A PyTorch tensor of the next state.
|
32 |
+
x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
|
33 |
+
"""
|
34 |
+
pass
|
35 |
+
|
36 |
+
|
37 |
+
@CorrectorRegistry.register(name='langevin')
|
38 |
+
class LangevinCorrector(Corrector):
|
39 |
+
def __init__(self, sde, score_fn, snr, n_steps):
|
40 |
+
super().__init__(sde, score_fn, snr, n_steps)
|
41 |
+
self.score_fn = score_fn
|
42 |
+
self.n_steps = n_steps
|
43 |
+
self.snr = snr
|
44 |
+
|
45 |
+
def update_fn(self, x, t, *args):
|
46 |
+
target_snr = self.snr
|
47 |
+
for _ in range(self.n_steps):
|
48 |
+
grad = self.score_fn(x, t, *args)
|
49 |
+
noise = torch.randn_like(x)
|
50 |
+
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
|
51 |
+
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
|
52 |
+
step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0)
|
53 |
+
x_mean = x + step_size[:, None, None, None] * grad
|
54 |
+
x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
|
55 |
+
|
56 |
+
return x, x_mean
|
57 |
+
|
58 |
+
|
59 |
+
@CorrectorRegistry.register(name='ald')
|
60 |
+
class AnnealedLangevinDynamics(Corrector):
|
61 |
+
"""The original annealed Langevin dynamics predictor in NCSN/NCSNv2."""
|
62 |
+
def __init__(self, sde, score_fn, snr, n_steps):
|
63 |
+
super().__init__(sde, score_fn, snr, n_steps)
|
64 |
+
if not isinstance(sde, (sdes.OUVESDE,)):
|
65 |
+
raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
|
66 |
+
self.sde = sde
|
67 |
+
self.score_fn = score_fn
|
68 |
+
self.snr = snr
|
69 |
+
self.n_steps = n_steps
|
70 |
+
|
71 |
+
def update_fn(self, x, t, *args):
|
72 |
+
n_steps = self.n_steps
|
73 |
+
target_snr = self.snr
|
74 |
+
std = self.sde.marginal_prob(x, t, *args)[1]
|
75 |
+
|
76 |
+
for _ in range(n_steps):
|
77 |
+
grad = self.score_fn(x, t, *args)
|
78 |
+
noise = torch.randn_like(x)
|
79 |
+
step_size = (target_snr * std) ** 2 * 2
|
80 |
+
x_mean = x + step_size[:, None, None, None] * grad
|
81 |
+
x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
|
82 |
+
|
83 |
+
return x, x_mean
|
84 |
+
|
85 |
+
|
86 |
+
@CorrectorRegistry.register(name='none')
|
87 |
+
class NoneCorrector(Corrector):
|
88 |
+
"""An empty corrector that does nothing."""
|
89 |
+
|
90 |
+
def __init__(self, *args, **kwargs):
|
91 |
+
self.snr = 0
|
92 |
+
self.n_steps = 0
|
93 |
+
pass
|
94 |
+
|
95 |
+
def update_fn(self, x, t, *args):
|
96 |
+
return x, x
|
sgmse/sampling/predictors.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from sgmse.util.registry import Registry
|
7 |
+
|
8 |
+
|
9 |
+
PredictorRegistry = Registry("Predictor")
|
10 |
+
|
11 |
+
|
12 |
+
class Predictor(abc.ABC):
|
13 |
+
"""The abstract class for a predictor algorithm."""
|
14 |
+
|
15 |
+
def __init__(self, sde, score_fn, probability_flow=False):
|
16 |
+
super().__init__()
|
17 |
+
self.sde = sde
|
18 |
+
self.rsde = sde.reverse(score_fn)
|
19 |
+
self.score_fn = score_fn
|
20 |
+
self.probability_flow = probability_flow
|
21 |
+
|
22 |
+
@abc.abstractmethod
|
23 |
+
def update_fn(self, x, t, *args):
|
24 |
+
"""One update of the predictor.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
x: A PyTorch tensor representing the current state
|
28 |
+
t: A Pytorch tensor representing the current time step.
|
29 |
+
*args: Possibly additional arguments, in particular `y` for OU processes
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
x: A PyTorch tensor of the next state.
|
33 |
+
x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
|
34 |
+
"""
|
35 |
+
pass
|
36 |
+
|
37 |
+
def debug_update_fn(self, x, t, *args):
|
38 |
+
raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")
|
39 |
+
|
40 |
+
|
41 |
+
@PredictorRegistry.register('euler_maruyama')
|
42 |
+
class EulerMaruyamaPredictor(Predictor):
|
43 |
+
def __init__(self, sde, score_fn, probability_flow=False):
|
44 |
+
super().__init__(sde, score_fn, probability_flow=probability_flow)
|
45 |
+
|
46 |
+
def update_fn(self, x, t, *args):
|
47 |
+
dt = -1. / self.rsde.N
|
48 |
+
z = torch.randn_like(x)
|
49 |
+
f, g = self.rsde.sde(x, t, *args)
|
50 |
+
x_mean = x + f * dt
|
51 |
+
x = x_mean + g[:, None, None, None] * np.sqrt(-dt) * z
|
52 |
+
return x, x_mean
|
53 |
+
|
54 |
+
|
55 |
+
@PredictorRegistry.register('reverse_diffusion')
|
56 |
+
class ReverseDiffusionPredictor(Predictor):
|
57 |
+
def __init__(self, sde, score_fn, probability_flow=False):
|
58 |
+
super().__init__(sde, score_fn, probability_flow=probability_flow)
|
59 |
+
|
60 |
+
def update_fn(self, x, t, y, stepsize):
|
61 |
+
f, g = self.rsde.discretize(x, t, y, stepsize)
|
62 |
+
z = torch.randn_like(x)
|
63 |
+
x_mean = x - f
|
64 |
+
x = x_mean + g[:, None, None, None] * z
|
65 |
+
return x, x_mean
|
66 |
+
|
67 |
+
|
68 |
+
@PredictorRegistry.register('none')
|
69 |
+
class NonePredictor(Predictor):
|
70 |
+
"""An empty predictor that does nothing."""
|
71 |
+
|
72 |
+
def __init__(self, *args, **kwargs):
|
73 |
+
pass
|
74 |
+
|
75 |
+
def update_fn(self, x, t, *args):
|
76 |
+
return x, x
|
sgmse/sdes.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
|
3 |
+
|
4 |
+
Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
|
5 |
+
"""
|
6 |
+
import abc
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from sgmse.util.tensors import batch_broadcast
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from sgmse.util.registry import Registry
|
14 |
+
|
15 |
+
|
16 |
+
SDERegistry = Registry("SDE")
|
17 |
+
|
18 |
+
|
19 |
+
class SDE(abc.ABC):
|
20 |
+
"""SDE abstract class. Functions are designed for a mini-batch of inputs."""
|
21 |
+
|
22 |
+
def __init__(self, N):
|
23 |
+
"""Construct an SDE.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
N: number of discretization time steps.
|
27 |
+
"""
|
28 |
+
super().__init__()
|
29 |
+
self.N = N
|
30 |
+
|
31 |
+
@property
|
32 |
+
@abc.abstractmethod
|
33 |
+
def T(self):
|
34 |
+
"""End time of the SDE."""
|
35 |
+
pass
|
36 |
+
|
37 |
+
@abc.abstractmethod
|
38 |
+
def sde(self, x, t, *args):
|
39 |
+
pass
|
40 |
+
|
41 |
+
@abc.abstractmethod
|
42 |
+
def marginal_prob(self, x, t, *args):
|
43 |
+
"""Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
|
44 |
+
pass
|
45 |
+
|
46 |
+
@abc.abstractmethod
|
47 |
+
def prior_sampling(self, shape, *args):
|
48 |
+
"""Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
|
49 |
+
pass
|
50 |
+
|
51 |
+
@abc.abstractmethod
|
52 |
+
def prior_logp(self, z):
|
53 |
+
"""Compute log-density of the prior distribution.
|
54 |
+
|
55 |
+
Useful for computing the log-likelihood via probability flow ODE.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
z: latent code
|
59 |
+
Returns:
|
60 |
+
log probability density
|
61 |
+
"""
|
62 |
+
pass
|
63 |
+
|
64 |
+
@staticmethod
|
65 |
+
@abc.abstractmethod
|
66 |
+
def add_argparse_args(parent_parser):
|
67 |
+
"""
|
68 |
+
Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
|
69 |
+
"""
|
70 |
+
pass
|
71 |
+
|
72 |
+
def discretize(self, x, t, y, stepsize):
|
73 |
+
"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
|
74 |
+
|
75 |
+
Useful for reverse diffusion sampling and probabiliy flow sampling.
|
76 |
+
Defaults to Euler-Maruyama discretization.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
x: a torch tensor
|
80 |
+
t: a torch float representing the time step (from 0 to `self.T`)
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
f, G
|
84 |
+
"""
|
85 |
+
dt = stepsize
|
86 |
+
drift, diffusion = self.sde(x, t, y)
|
87 |
+
f = drift * dt
|
88 |
+
G = diffusion * torch.sqrt(dt)
|
89 |
+
return f, G
|
90 |
+
|
91 |
+
def reverse(oself, score_model, probability_flow=False):
|
92 |
+
"""Create the reverse-time SDE/ODE.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
score_model: A function that takes x, t and y and returns the score.
|
96 |
+
probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
|
97 |
+
"""
|
98 |
+
N = oself.N
|
99 |
+
T = oself.T
|
100 |
+
sde_fn = oself.sde
|
101 |
+
discretize_fn = oself.discretize
|
102 |
+
|
103 |
+
# Build the class for reverse-time SDE.
|
104 |
+
class RSDE(oself.__class__):
|
105 |
+
def __init__(self):
|
106 |
+
self.N = N
|
107 |
+
self.probability_flow = probability_flow
|
108 |
+
|
109 |
+
@property
|
110 |
+
def T(self):
|
111 |
+
return T
|
112 |
+
|
113 |
+
def sde(self, x, t, *args):
|
114 |
+
"""Create the drift and diffusion functions for the reverse SDE/ODE."""
|
115 |
+
rsde_parts = self.rsde_parts(x, t, *args)
|
116 |
+
total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
|
117 |
+
return total_drift, diffusion
|
118 |
+
|
119 |
+
def rsde_parts(self, x, t, *args):
|
120 |
+
sde_drift, sde_diffusion = sde_fn(x, t, *args)
|
121 |
+
score = score_model(x, t, *args)
|
122 |
+
score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
|
123 |
+
diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
|
124 |
+
total_drift = sde_drift + score_drift
|
125 |
+
return {
|
126 |
+
'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
|
127 |
+
'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
|
128 |
+
}
|
129 |
+
|
130 |
+
def discretize(self, x, t, y, stepsize):
|
131 |
+
"""Create discretized iteration rules for the reverse diffusion sampler."""
|
132 |
+
f, G = discretize_fn(x, t, y, stepsize)
|
133 |
+
rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
|
134 |
+
rev_G = torch.zeros_like(G) if self.probability_flow else G
|
135 |
+
return rev_f, rev_G
|
136 |
+
|
137 |
+
return RSDE()
|
138 |
+
|
139 |
+
@abc.abstractmethod
|
140 |
+
def copy(self):
|
141 |
+
pass
|
142 |
+
|
143 |
+
|
144 |
+
@SDERegistry.register("ouve")
|
145 |
+
class OUVESDE(SDE):
|
146 |
+
@staticmethod
|
147 |
+
def add_argparse_args(parser):
|
148 |
+
parser.add_argument("--sde-n", type=int, default=1000, help="The number of timesteps in the SDE discretization. 30 by default")
|
149 |
+
parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
|
150 |
+
parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
|
151 |
+
parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
|
152 |
+
return parser
|
153 |
+
|
154 |
+
def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
|
155 |
+
"""Construct an Ornstein-Uhlenbeck Variance Exploding SDE.
|
156 |
+
|
157 |
+
Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
|
158 |
+
to the methods which require it (e.g., `sde` or `marginal_prob`).
|
159 |
+
|
160 |
+
dx = -theta (y-x) dt + sigma(t) dw
|
161 |
+
|
162 |
+
with
|
163 |
+
|
164 |
+
sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))
|
165 |
+
|
166 |
+
Args:
|
167 |
+
theta: stiffness parameter.
|
168 |
+
sigma_min: smallest sigma.
|
169 |
+
sigma_max: largest sigma.
|
170 |
+
N: number of discretization steps
|
171 |
+
"""
|
172 |
+
super().__init__(N)
|
173 |
+
self.theta = theta
|
174 |
+
self.sigma_min = sigma_min
|
175 |
+
self.sigma_max = sigma_max
|
176 |
+
self.logsig = np.log(self.sigma_max / self.sigma_min)
|
177 |
+
self.N = N
|
178 |
+
|
179 |
+
def copy(self):
|
180 |
+
return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)
|
181 |
+
|
182 |
+
@property
|
183 |
+
def T(self):
|
184 |
+
return 1
|
185 |
+
|
186 |
+
def sde(self, x, t, y):
|
187 |
+
drift = self.theta * (y - x)
|
188 |
+
# the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel
|
189 |
+
# standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t
|
190 |
+
# with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
|
191 |
+
# unless this sqrt(2*logsig) factor is included.
|
192 |
+
sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
|
193 |
+
diffusion = sigma * np.sqrt(2 * self.logsig)
|
194 |
+
return drift, diffusion
|
195 |
+
|
196 |
+
def _mean(self, x0, t, y):
|
197 |
+
theta = self.theta
|
198 |
+
exp_interp = torch.exp(-theta * t)[:, None, None, None]
|
199 |
+
return exp_interp * x0 + (1 - exp_interp) * y
|
200 |
+
|
201 |
+
def alpha(self, t):
|
202 |
+
return torch.exp(-self.theta * t)
|
203 |
+
|
204 |
+
def _std(self, t):
|
205 |
+
# This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
|
206 |
+
sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
|
207 |
+
# could maybe replace the two torch.exp(... * t) terms here by cached values **t
|
208 |
+
return torch.sqrt(
|
209 |
+
(
|
210 |
+
sigma_min**2
|
211 |
+
* torch.exp(-2 * theta * t)
|
212 |
+
* (torch.exp(2 * (theta + logsig) * t) - 1)
|
213 |
+
* logsig
|
214 |
+
)
|
215 |
+
/
|
216 |
+
(theta + logsig)
|
217 |
+
)
|
218 |
+
|
219 |
+
def marginal_prob(self, x0, t, y):
|
220 |
+
return self._mean(x0, t, y), self._std(t)
|
221 |
+
|
222 |
+
def prior_sampling(self, shape, y):
|
223 |
+
if shape != y.shape:
|
224 |
+
warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
|
225 |
+
std = self._std(torch.ones((y.shape[0],), device=y.device))
|
226 |
+
x_T = y + torch.randn_like(y) * std[:, None, None, None]
|
227 |
+
return x_T
|
228 |
+
|
229 |
+
def prior_logp(self, z):
|
230 |
+
raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
|
231 |
+
|
232 |
+
|
233 |
+
@SDERegistry.register("ouvp")
|
234 |
+
class OUVPSDE(SDE):
|
235 |
+
# !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
|
236 |
+
@staticmethod
|
237 |
+
def add_argparse_args(parser):
|
238 |
+
parser.add_argument("--sde-n", type=int, default=1000,
|
239 |
+
help="The number of timesteps in the SDE discretization. 1000 by default")
|
240 |
+
parser.add_argument("--beta-min", type=float, required=True,
|
241 |
+
help="The minimum beta to use.")
|
242 |
+
parser.add_argument("--beta-max", type=float, required=True,
|
243 |
+
help="The maximum beta to use.")
|
244 |
+
parser.add_argument("--stiffness", type=float, default=1,
|
245 |
+
help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
|
246 |
+
return parser
|
247 |
+
|
248 |
+
def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
|
249 |
+
"""
|
250 |
+
!!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
|
251 |
+
|
252 |
+
Construct an Ornstein-Uhlenbeck Variance Preserving SDE:
|
253 |
+
|
254 |
+
dx = -1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw
|
255 |
+
|
256 |
+
with
|
257 |
+
|
258 |
+
beta(t) = beta_min + t(beta_max - beta_min)
|
259 |
+
|
260 |
+
Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
|
261 |
+
to the methods which require it (e.g., `sde` or `marginal_prob`).
|
262 |
+
|
263 |
+
Args:
|
264 |
+
beta_min: smallest sigma.
|
265 |
+
beta_max: largest sigma.
|
266 |
+
stiffness: stiffness factor of the drift. 1 by default.
|
267 |
+
N: number of discretization steps
|
268 |
+
"""
|
269 |
+
super().__init__(N)
|
270 |
+
self.beta_min = beta_min
|
271 |
+
self.beta_max = beta_max
|
272 |
+
self.stiffness = stiffness
|
273 |
+
self.N = N
|
274 |
+
|
275 |
+
def copy(self):
|
276 |
+
return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)
|
277 |
+
|
278 |
+
@property
|
279 |
+
def T(self):
|
280 |
+
return 1
|
281 |
+
|
282 |
+
def _beta(self, t):
|
283 |
+
return self.beta_min + t * (self.beta_max - self.beta_min)
|
284 |
+
|
285 |
+
def sde(self, x, t, y):
|
286 |
+
drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
|
287 |
+
diffusion = torch.sqrt(self._beta(t))
|
288 |
+
return drift, diffusion
|
289 |
+
|
290 |
+
def _mean(self, x0, t, y):
|
291 |
+
b0, b1, s = self.beta_min, self.beta_max, self.stiffness
|
292 |
+
x0y_fac = torch.exp(-0.25 * s * t * (t * (b1-b0) + 2 * b0))[:, None, None, None]
|
293 |
+
return y + x0y_fac * (x0 - y)
|
294 |
+
|
295 |
+
def _std(self, t):
|
296 |
+
b0, b1, s = self.beta_min, self.beta_max, self.stiffness
|
297 |
+
return (1 - torch.exp(-0.5 * s * t * (t * (b1-b0) + 2 * b0))) / s
|
298 |
+
|
299 |
+
def marginal_prob(self, x0, t, y):
|
300 |
+
return self._mean(x0, t, y), self._std(t)
|
301 |
+
|
302 |
+
def prior_sampling(self, shape, y):
|
303 |
+
if shape != y.shape:
|
304 |
+
warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
|
305 |
+
std = self._std(torch.ones((y.shape[0],), device=y.device))
|
306 |
+
x_T = y + torch.randn_like(y) * std[:, None, None, None]
|
307 |
+
return x_T
|
308 |
+
|
309 |
+
def prior_logp(self, z):
|
310 |
+
raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
|
sgmse/util/inference.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchaudio import load
|
3 |
+
|
4 |
+
from pesq import pesq
|
5 |
+
from pystoi import stoi
|
6 |
+
|
7 |
+
from .other import si_sdr, pad_spec
|
8 |
+
|
9 |
+
# Settings
|
10 |
+
sr = 16000
|
11 |
+
snr = 0.5
|
12 |
+
N = 30
|
13 |
+
corrector_steps = 1
|
14 |
+
|
15 |
+
|
16 |
+
def evaluate_model(model, num_eval_files):
|
17 |
+
|
18 |
+
clean_files = model.data_module.valid_set.clean_files
|
19 |
+
noisy_files = model.data_module.valid_set.noisy_files
|
20 |
+
|
21 |
+
# Select test files uniformly accros validation files
|
22 |
+
total_num_files = len(clean_files)
|
23 |
+
indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
|
24 |
+
clean_files = list(clean_files[i] for i in indices)
|
25 |
+
noisy_files = list(noisy_files[i] for i in indices)
|
26 |
+
|
27 |
+
_pesq = 0
|
28 |
+
_si_sdr = 0
|
29 |
+
_estoi = 0
|
30 |
+
# iterate over files
|
31 |
+
for (clean_file, noisy_file) in zip(clean_files, noisy_files):
|
32 |
+
# Load wavs
|
33 |
+
x, _ = load(clean_file)
|
34 |
+
y, _ = load(noisy_file)
|
35 |
+
T_orig = x.size(1)
|
36 |
+
|
37 |
+
# Normalize per utterance
|
38 |
+
norm_factor = y.abs().max()
|
39 |
+
y = y / norm_factor
|
40 |
+
|
41 |
+
# Prepare DNN input
|
42 |
+
Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
|
43 |
+
Y = pad_spec(Y)
|
44 |
+
y = y * norm_factor
|
45 |
+
|
46 |
+
# Reverse sampling
|
47 |
+
sampler = model.get_pc_sampler(
|
48 |
+
'reverse_diffusion', 'ald', Y.cuda(), N=N,
|
49 |
+
corrector_steps=corrector_steps, snr=snr)
|
50 |
+
sample, _ = sampler()
|
51 |
+
|
52 |
+
x_hat = model.to_audio(sample.squeeze(), T_orig)
|
53 |
+
x_hat = x_hat * norm_factor
|
54 |
+
|
55 |
+
x_hat = x_hat.squeeze().cpu().numpy()
|
56 |
+
x = x.squeeze().cpu().numpy()
|
57 |
+
y = y.squeeze().cpu().numpy()
|
58 |
+
|
59 |
+
_si_sdr += si_sdr(x, x_hat)
|
60 |
+
_pesq += pesq(sr, x, x_hat, 'wb')
|
61 |
+
_estoi += stoi(x, x_hat, sr, extended=True)
|
62 |
+
|
63 |
+
return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
|
64 |
+
|
sgmse/util/other.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import scipy.stats
|
5 |
+
from scipy.signal import butter, sosfilt
|
6 |
+
|
7 |
+
from pesq import pesq
|
8 |
+
from pystoi import stoi
|
9 |
+
|
10 |
+
|
11 |
+
def si_sdr_components(s_hat, s, n):
|
12 |
+
# s_target
|
13 |
+
alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
|
14 |
+
s_target = alpha_s * s
|
15 |
+
|
16 |
+
# e_noise
|
17 |
+
alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
|
18 |
+
e_noise = alpha_n * n
|
19 |
+
|
20 |
+
# e_art
|
21 |
+
e_art = s_hat - s_target - e_noise
|
22 |
+
|
23 |
+
return s_target, e_noise, e_art
|
24 |
+
|
25 |
+
def energy_ratios(s_hat, s, n):
|
26 |
+
s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
|
27 |
+
|
28 |
+
si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
|
29 |
+
si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
|
30 |
+
si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)
|
31 |
+
|
32 |
+
return si_sdr, si_sir, si_sar
|
33 |
+
|
34 |
+
def mean_conf_int(data, confidence=0.95):
|
35 |
+
a = 1.0 * np.array(data)
|
36 |
+
n = len(a)
|
37 |
+
m, se = np.mean(a), scipy.stats.sem(a)
|
38 |
+
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
|
39 |
+
return m, h
|
40 |
+
|
41 |
+
class Method():
|
42 |
+
def __init__(self, name, base_dir, metrics):
|
43 |
+
self.name = name
|
44 |
+
self.base_dir = base_dir
|
45 |
+
self.metrics = {}
|
46 |
+
|
47 |
+
for i in range(len(metrics)):
|
48 |
+
metric = metrics[i]
|
49 |
+
value = []
|
50 |
+
self.metrics[metric] = value
|
51 |
+
|
52 |
+
def append(self, matric, value):
|
53 |
+
self.metrics[matric].append(value)
|
54 |
+
|
55 |
+
def get_mean_ci(self, metric):
|
56 |
+
return mean_conf_int(np.array(self.metrics[metric]))
|
57 |
+
|
58 |
+
def hp_filter(signal, cut_off=80, order=10, sr=16000):
|
59 |
+
factor = cut_off /sr * 2
|
60 |
+
sos = butter(order, factor, 'hp', output='sos')
|
61 |
+
filtered = sosfilt(sos, signal)
|
62 |
+
return filtered
|
63 |
+
|
64 |
+
def si_sdr(s, s_hat):
|
65 |
+
alpha = np.dot(s_hat, s)/np.linalg.norm(s)**2
|
66 |
+
sdr = 10*np.log10(np.linalg.norm(alpha*s)**2/np.linalg.norm(
|
67 |
+
alpha*s - s_hat)**2)
|
68 |
+
return sdr
|
69 |
+
|
70 |
+
def snr_dB(s,n):
|
71 |
+
s_power = 1/len(s)*np.sum(s**2)
|
72 |
+
n_power = 1/len(n)*np.sum(n**2)
|
73 |
+
snr_dB = 10*np.log10(s_power/n_power)
|
74 |
+
return snr_dB
|
75 |
+
|
76 |
+
def pad_spec(Y, mode="zero_pad"):
|
77 |
+
T = Y.size(3)
|
78 |
+
if T%64 !=0:
|
79 |
+
num_pad = 64-T%64
|
80 |
+
else:
|
81 |
+
num_pad = 0
|
82 |
+
if mode == "zero_pad":
|
83 |
+
pad2d = torch.nn.ZeroPad2d((0, num_pad, 0,0))
|
84 |
+
elif mode == "reflection":
|
85 |
+
pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0,0))
|
86 |
+
elif mode == "replication":
|
87 |
+
pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0,0))
|
88 |
+
else:
|
89 |
+
raise NotImplementedError("This function hasn't been implemented yet.")
|
90 |
+
return pad2d(Y)
|
91 |
+
|
92 |
+
def ensure_dir(file_path):
|
93 |
+
directory = file_path
|
94 |
+
if not os.path.exists(directory):
|
95 |
+
os.makedirs(directory)
|
96 |
+
|
97 |
+
|
98 |
+
def print_metrics(x, y, x_hat_list, labels, sr=16000):
|
99 |
+
_si_sdr_mix = si_sdr(x, y)
|
100 |
+
_pesq_mix = pesq(sr, x, y, 'wb')
|
101 |
+
_estoi_mix = stoi(x, y, sr, extended=True)
|
102 |
+
print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
|
103 |
+
for i, x_hat in enumerate(x_hat_list):
|
104 |
+
_si_sdr = si_sdr(x, x_hat)
|
105 |
+
_pesq = pesq(sr, x, x_hat, 'wb')
|
106 |
+
_estoi = stoi(x, x_hat, sr, extended=True)
|
107 |
+
print(f'{labels[i]}: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
|
108 |
+
|
109 |
+
def mean_std(data):
|
110 |
+
data = data[~np.isnan(data)]
|
111 |
+
mean = np.mean(data)
|
112 |
+
std = np.std(data)
|
113 |
+
return mean, std
|
114 |
+
|
115 |
+
def print_mean_std(data, decimal=2):
|
116 |
+
data = np.array(data)
|
117 |
+
data = data[~np.isnan(data)]
|
118 |
+
mean = np.mean(data)
|
119 |
+
std = np.std(data)
|
120 |
+
if decimal == 2:
|
121 |
+
string = f'{mean:.2f} ± {std:.2f}'
|
122 |
+
elif decimal == 1:
|
123 |
+
string = f'{mean:.1f} ± {std:.1f}'
|
124 |
+
return string
|
125 |
+
|
126 |
+
def set_torch_cuda_arch_list():
|
127 |
+
if not torch.cuda.is_available():
|
128 |
+
print("CUDA is not available. No GPUs found.")
|
129 |
+
return
|
130 |
+
|
131 |
+
num_gpus = torch.cuda.device_count()
|
132 |
+
compute_capabilities = []
|
133 |
+
|
134 |
+
for i in range(num_gpus):
|
135 |
+
cc_major, cc_minor = torch.cuda.get_device_capability(i)
|
136 |
+
cc = f"{cc_major}.{cc_minor}"
|
137 |
+
compute_capabilities.append(cc)
|
138 |
+
|
139 |
+
cc_string = ";".join(compute_capabilities)
|
140 |
+
os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
|
141 |
+
print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")
|
sgmse/util/registry.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
|
5 |
+
class Registry:
|
6 |
+
def __init__(self, managed_thing: str):
|
7 |
+
"""
|
8 |
+
Create a new registry.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
managed_thing: A string describing what type of thing is managed by this registry. Will be used for
|
12 |
+
warnings and errors, so it's a good idea to keep this string globally unique and easily understood.
|
13 |
+
"""
|
14 |
+
self.managed_thing = managed_thing
|
15 |
+
self._registry = {}
|
16 |
+
|
17 |
+
def register(self, name: str) -> Callable:
|
18 |
+
def inner_wrapper(wrapped_class) -> Callable:
|
19 |
+
if name in self._registry:
|
20 |
+
warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.")
|
21 |
+
self._registry[name] = wrapped_class
|
22 |
+
return wrapped_class
|
23 |
+
return inner_wrapper
|
24 |
+
|
25 |
+
def get_by_name(self, name: str):
|
26 |
+
"""Get a managed thing by name."""
|
27 |
+
if name in self._registry:
|
28 |
+
return self._registry[name]
|
29 |
+
else:
|
30 |
+
raise ValueError(f"{self.managed_thing} with name '{name}' unknown.")
|
31 |
+
|
32 |
+
def get_all_names(self):
|
33 |
+
"""Get the list of things' names registered to this registry."""
|
34 |
+
return list(self._registry.keys())
|
sgmse/util/tensors.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def batch_broadcast(a, x):
|
2 |
+
"""Broadcasts a over all dimensions of x, except the batch dimension, which must match."""
|
3 |
+
|
4 |
+
if len(a.shape) != 1:
|
5 |
+
a = a.squeeze()
|
6 |
+
if len(a.shape) != 1:
|
7 |
+
raise ValueError(
|
8 |
+
f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
|
9 |
+
)
|
10 |
+
|
11 |
+
if a.shape[0] != x.shape[0] and a.shape[0] != 1:
|
12 |
+
raise ValueError(
|
13 |
+
f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")
|
14 |
+
|
15 |
+
out = a.view((x.shape[0], *(1 for _ in range(len(x.shape)-1))))
|
16 |
+
return out
|
train.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import wandb
|
2 |
+
import argparse
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
from pytorch_lightning.loggers import WandbLogger
|
7 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
8 |
+
from os.path import join
|
9 |
+
|
10 |
+
# Set CUDA architecture list
|
11 |
+
from sgmse.util.other import set_torch_cuda_arch_list
|
12 |
+
set_torch_cuda_arch_list()
|
13 |
+
|
14 |
+
from sgmse.backbones.shared import BackboneRegistry
|
15 |
+
from sgmse.data_module import SpecsDataModule
|
16 |
+
from sgmse.sdes import SDERegistry
|
17 |
+
from sgmse.model import ScoreModel
|
18 |
+
|
19 |
+
|
20 |
+
def get_argparse_groups(parser):
|
21 |
+
groups = {}
|
22 |
+
for group in parser._action_groups:
|
23 |
+
group_dict = { a.dest: getattr(args, a.dest, None) for a in group._group_actions }
|
24 |
+
groups[group.title] = argparse.Namespace(**group_dict)
|
25 |
+
return groups
|
26 |
+
|
27 |
+
|
28 |
+
if __name__ == '__main__':
|
29 |
+
# throwaway parser for dynamic args - see https://stackoverflow.com/a/25320537/3090225
|
30 |
+
base_parser = ArgumentParser(add_help=False)
|
31 |
+
parser = ArgumentParser()
|
32 |
+
for parser_ in (base_parser, parser):
|
33 |
+
parser_.add_argument("--backbone", type=str, choices=BackboneRegistry.get_all_names(), default="ncsnpp")
|
34 |
+
parser_.add_argument("--sde", type=str, choices=SDERegistry.get_all_names(), default="ouve")
|
35 |
+
parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
|
36 |
+
parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
|
37 |
+
parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
|
38 |
+
parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
|
39 |
+
|
40 |
+
temp_args, _ = base_parser.parse_known_args()
|
41 |
+
|
42 |
+
# Add specific args for ScoreModel, pl.Trainer, the SDE class and backbone DNN class
|
43 |
+
backbone_cls = BackboneRegistry.get_by_name(temp_args.backbone)
|
44 |
+
sde_class = SDERegistry.get_by_name(temp_args.sde)
|
45 |
+
trainer_parser = parser.add_argument_group("Trainer", description="Lightning Trainer")
|
46 |
+
trainer_parser.add_argument("--accelerator", type=str, default="gpu", help="Supports passing different accelerator types.")
|
47 |
+
trainer_parser.add_argument("--devices", default="auto", help="How many gpus to use.")
|
48 |
+
trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients.")
|
49 |
+
|
50 |
+
ScoreModel.add_argparse_args(
|
51 |
+
parser.add_argument_group("ScoreModel", description=ScoreModel.__name__))
|
52 |
+
sde_class.add_argparse_args(
|
53 |
+
parser.add_argument_group("SDE", description=sde_class.__name__))
|
54 |
+
backbone_cls.add_argparse_args(
|
55 |
+
parser.add_argument_group("Backbone", description=backbone_cls.__name__))
|
56 |
+
# Add data module args
|
57 |
+
data_module_cls = SpecsDataModule
|
58 |
+
data_module_cls.add_argparse_args(
|
59 |
+
parser.add_argument_group("DataModule", description=data_module_cls.__name__))
|
60 |
+
# Parse args and separate into groups
|
61 |
+
args = parser.parse_args()
|
62 |
+
arg_groups = get_argparse_groups(parser)
|
63 |
+
|
64 |
+
# Initialize logger, trainer, model, datamodule
|
65 |
+
model = ScoreModel(
|
66 |
+
backbone=args.backbone, sde=args.sde, data_module_cls=data_module_cls,
|
67 |
+
**{
|
68 |
+
**vars(arg_groups['ScoreModel']),
|
69 |
+
**vars(arg_groups['SDE']),
|
70 |
+
**vars(arg_groups['Backbone']),
|
71 |
+
**vars(arg_groups['DataModule'])
|
72 |
+
}
|
73 |
+
)
|
74 |
+
|
75 |
+
# Set up logger configuration
|
76 |
+
if args.nolog:
|
77 |
+
logger = None
|
78 |
+
else:
|
79 |
+
logger = WandbLogger(project="sgmse", log_model=True, save_dir="logs", name=args.wandb_name)
|
80 |
+
logger.experiment.log_code(".")
|
81 |
+
|
82 |
+
# Set up callbacks for logger
|
83 |
+
if logger != None:
|
84 |
+
callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
|
85 |
+
if args.num_eval_files:
|
86 |
+
checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
|
87 |
+
save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
|
88 |
+
checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
|
89 |
+
save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
|
90 |
+
callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
|
91 |
+
else:
|
92 |
+
callbacks = None
|
93 |
+
|
94 |
+
# Initialize the Trainer and the DataModule
|
95 |
+
trainer = pl.Trainer(
|
96 |
+
**vars(arg_groups['Trainer']),
|
97 |
+
strategy="ddp", logger=logger,
|
98 |
+
log_every_n_steps=10, num_sanity_val_steps=0,
|
99 |
+
callbacks=callbacks
|
100 |
+
)
|
101 |
+
|
102 |
+
# Train model
|
103 |
+
trainer.fit(model, ckpt_path=args.ckpt)
|