Shokoufehhh commited on
Commit
05b4fca
·
verified ·
1 Parent(s): 2bfe98e

Upload 40 files

Browse files
.gitignore ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,macos,linux,emacs,windows,jupyternotebooks
3
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,macos,linux,emacs,windows,jupyternotebooks
4
+
5
+ ### Emacs ###
6
+ # -*- mode: gitignore; -*-
7
+ *~
8
+ \#*\#
9
+ /.emacs.desktop
10
+ /.emacs.desktop.lock
11
+ *.elc
12
+ auto-save-list
13
+ tramp
14
+ .\#*
15
+
16
+ # Org-mode
17
+ .org-id-locations
18
+ *_archive
19
+
20
+ # flymake-mode
21
+ *_flymake.*
22
+
23
+ # figures
24
+ figures/
25
+ listen/
26
+ enhanced/
27
+ checkpoints/
28
+ baselines/
29
+ __paths__.py
30
+ *.csv
31
+ *.pdf
32
+ *.jpg
33
+ *.png
34
+ *.html
35
+ mushra/
36
+
37
+ # eshell files
38
+ /eshell/history
39
+ /eshell/lastdir
40
+
41
+ # elpa packages
42
+ /elpa/
43
+
44
+ # reftex files
45
+ *.rel
46
+
47
+ # AUCTeX auto folder
48
+ /auto/
49
+
50
+ # cask packages
51
+ .cask/
52
+ dist/
53
+
54
+ # Flycheck
55
+ flycheck_*.el
56
+
57
+ # server auth directory
58
+ /server/
59
+
60
+ # projectiles files
61
+ .projectile
62
+
63
+ # directory configuration
64
+ .dir-locals.el
65
+
66
+ # network security
67
+ /network-security.data
68
+
69
+
70
+ ### JupyterNotebooks ###
71
+ # gitignore template for Jupyter Notebooks
72
+ # website: http://jupyter.org/
73
+
74
+ .ipynb_checkpoints
75
+ */.ipynb_checkpoints/*
76
+
77
+ # IPython
78
+ profile_default/
79
+ ipython_config.py
80
+
81
+ # Remove previous ipynb_checkpoints
82
+ # git rm -r .ipynb_checkpoints/
83
+
84
+ ### Linux ###
85
+
86
+ # temporary files which can be created if a process still has a handle open of a deleted file
87
+ .fuse_hidden*
88
+
89
+ # KDE directory preferences
90
+ .directory
91
+
92
+ # Linux trash folder which might appear on any partition or disk
93
+ .Trash-*
94
+
95
+ # .nfs files are created when an open file is removed but is still being accessed
96
+ .nfs*
97
+
98
+ ### macOS ###
99
+ # General
100
+ .DS_Store
101
+ .AppleDouble
102
+ .LSOverride
103
+
104
+ # Icon must end with two \r
105
+ Icon
106
+
107
+
108
+ # Thumbnails
109
+ ._*
110
+
111
+ # Files that might appear in the root of a volume
112
+ .DocumentRevisions-V100
113
+ .fseventsd
114
+ .Spotlight-V100
115
+ .TemporaryItems
116
+ .Trashes
117
+ .VolumeIcon.icns
118
+ .com.apple.timemachine.donotpresent
119
+
120
+ # Directories potentially created on remote AFP share
121
+ .AppleDB
122
+ .AppleDesktop
123
+ Network Trash Folder
124
+ Temporary Items
125
+ .apdisk
126
+
127
+ ### Python ###
128
+ # Byte-compiled / optimized / DLL files
129
+ __pycache__/
130
+ *.py[cod]
131
+ *$py.class
132
+
133
+ # C extensions
134
+ *.so
135
+
136
+ # Distribution / packaging
137
+ .Python
138
+ build/
139
+ develop-eggs/
140
+ downloads/
141
+ eggs/
142
+ .eggs/
143
+ lib/
144
+ lib64/
145
+ parts/
146
+ sdist/
147
+ var/
148
+ wheels/
149
+ share/python-wheels/
150
+ *.egg-info/
151
+ .installed.cfg
152
+ *.egg
153
+ MANIFEST
154
+
155
+ # PyInstaller
156
+ # Usually these files are written by a python script from a template
157
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
158
+ *.manifest
159
+ *.spec
160
+
161
+ # Installer logs
162
+ pip-log.txt
163
+ pip-delete-this-directory.txt
164
+
165
+ # Unit test / coverage reports
166
+ htmlcov/
167
+ .tox/
168
+ .nox/
169
+ .coverage
170
+ .coverage.*
171
+ .cache
172
+ nosetests.xml
173
+ coverage.xml
174
+ *.cover
175
+ *.py,cover
176
+ .hypothesis/
177
+ .pytest_cache/
178
+ cover/
179
+
180
+ # Translations
181
+ *.mo
182
+ *.pot
183
+
184
+ # Django stuff:
185
+ *.log
186
+ local_settings.py
187
+ db.sqlite3
188
+ db.sqlite3-journal
189
+
190
+ # Flask stuff:
191
+ instance/
192
+ .webassets-cache
193
+
194
+ # Scrapy stuff:
195
+ .scrapy
196
+
197
+ # Sphinx documentation
198
+ docs/_build/
199
+
200
+ # PyBuilder
201
+ .pybuilder/
202
+ target/
203
+
204
+ # Jupyter Notebook
205
+
206
+ # IPython
207
+
208
+ # pyenv
209
+ # For a library or package, you might want to ignore these files since the code is
210
+ # intended to run in multiple environments; otherwise, check them in:
211
+ # .python-version
212
+
213
+ # pipenv
214
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
215
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
216
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
217
+ # install all needed dependencies.
218
+ #Pipfile.lock
219
+
220
+ # poetry
221
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
222
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
223
+ # commonly ignored for libraries.
224
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
225
+ #poetry.lock
226
+
227
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
228
+ __pypackages__/
229
+
230
+ # Celery stuff
231
+ celerybeat-schedule
232
+ celerybeat.pid
233
+
234
+ # SageMath parsed files
235
+ *.sage.py
236
+
237
+ # Environments
238
+ .env
239
+ .venv
240
+ env/
241
+ venv/
242
+ ENV/
243
+ env.bak/
244
+ venv.bak/
245
+
246
+ # Spyder project settings
247
+ .spyderproject
248
+ .spyproject
249
+
250
+ # Rope project settings
251
+ .ropeproject
252
+
253
+ # mkdocs documentation
254
+ /site
255
+
256
+ # mypy
257
+ .mypy_cache/
258
+ .dmypy.json
259
+ dmypy.json
260
+
261
+ # Pyre type checker
262
+ .pyre/
263
+
264
+ # pytype static type analyzer
265
+ .pytype/
266
+
267
+ # Cython debug symbols
268
+ cython_debug/
269
+
270
+ # PyCharm
271
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
272
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
273
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
274
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
275
+ #.idea/
276
+
277
+ ### VisualStudioCode ###
278
+ .vscode/
279
+ .vscode/*
280
+ !.vscode/settings.json
281
+ !.vscode/tasks.json
282
+ !.vscode/launch.json
283
+ !.vscode/extensions.json
284
+ !.vscode/*.code-snippets
285
+
286
+ # Local History for Visual Studio Code
287
+ .history/
288
+
289
+ # Built Visual Studio Code Extensions
290
+ *.vsix
291
+
292
+ ### VisualStudioCode Patch ###
293
+ # Ignore all local history of files
294
+ .history
295
+ .ionide
296
+
297
+ # Support for Project snippet scope
298
+
299
+ ### Windows ###
300
+ # Windows thumbnail cache files
301
+ Thumbs.db
302
+ Thumbs.db:encryptable
303
+ ehthumbs.db
304
+ ehthumbs_vista.db
305
+
306
+ # Dump file
307
+ *.stackdump
308
+
309
+ # Folder config file
310
+ [Dd]esktop.ini
311
+
312
+ # Recycle Bin used on file shares
313
+ $RECYCLE.BIN/
314
+
315
+ # Windows Installer files
316
+ *.cab
317
+ *.msi
318
+ *.msix
319
+ *.msm
320
+ *.msp
321
+
322
+ # Windows shortcuts
323
+ *.lnk
324
+
325
+ # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,macos,linux,emacs,windows,jupyternotebooks
326
+
327
+
328
+
329
+ # Custom ignores:
330
+
331
+ ## Logs from W&B and PyTorch Lightning
332
+ /wandb
333
+ /lightning_logs
334
+ /logs
335
+ /jobs
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Signal Processing (SP), Universität Hamburg
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ tags:
4
+ - speech-enhancement
5
+ - dereverberation
6
+ - diffusion-models
7
+ - generative-models
8
+ - pytorch
9
+ - audio-processing
10
+ license: mit
11
+ datasets:
12
+ - VoiceBank-DEMAND
13
+ - WSJ0-CHiME3
14
+ - WSJ0-REVERB
15
+ - EARS-WHAM
16
+ - EARS-Reverb
17
+ model_name: speech-enhancement-dereverberation-diffusion
18
+ model_type: diffusion-based-generative-model
19
+ library_name: pytorch
20
+ inference: true
21
+ ---
22
+
23
+
24
+ # Speech Enhancement and Dereverberation with Diffusion-based Generative Models
25
+
26
+ <img src="https://raw.githubusercontent.com/sp-uhh/sgmse/main/diffusion_process.png" width="500" alt="Diffusion process on a spectrogram: In the forward process noise is gradually added to the clean speech spectrogram x0, while the reverse process learns to generate clean speech in an iterative fashion starting from the corrupted signal xT.">
27
+
28
+ This repository contains the official PyTorch implementations for the papers:
29
+
30
+ - Simon Welker, Julius Richter, Timo Gerkmann, [*"Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain"*](https://www.isca-speech.org/archive/interspeech_2022/welker22_interspeech.html), ISCA Interspeech, Incheon, Korea, Sept. 2022. [[bibtex]](#citations--references)
31
+ - Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann, [*"Speech Enhancement and Dereverberation with Diffusion-Based Generative Models"*](https://ieeexplore.ieee.org/abstract/document/10149431), IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023. [[bibtex]](#citations--references)
32
+ - Julius Richter, Yi-Chiao Wu, Steven Krenn, Simon Welker, Bunlong Lay, Shinji Watanabe, Alexander Richard, Timo Gerkmann, [*"EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation"*](https://arxiv.org/abs/2406.06185), ISCA Interspecch, Kos, Greece, Sept. 2024. [[bibtex]](#citations--references)
33
+
34
+ Audio examples and supplementary materials are available on our [SGMSE project page](https://www.inf.uni-hamburg.de/en/inst/ab/sp/publications/sgmse) and [EARS project page](https://sp-uhh.github.io/ears_dataset/).
35
+
36
+ ## Follow-up work
37
+
38
+ Please also check out our follow-up work with code available:
39
+
40
+ - Jean-Marie Lemercier, Julius Richter, Simon Welker, Timo Gerkmann, [*"StoRM: A Diffusion-based Stochastic Regeneration Model for Speech Enhancement and Dereverberation"*](https://ieeexplore.ieee.org/document/10180108), IEEE/ACM Transactions on Audio, Speech, Language Processing, vol. 31, pp. 2724 -2737, 2023. [[github]](https://github.com/sp-uhh/storm)
41
+ - Bunlong Lay, Simon Welker, Julius Richter, Timo Gerkmann, [*"Reducing the Prior Mismatch of Stochastic Differential Equations for Diffusion-based Speech Enhancement"*](https://www.isca-archive.org/interspeech_2023/lay23_interspeech.html), ISCA Interspeech, Dublin, Ireland, Aug. 2023. [[github]](https://github.com/sp-uhh/sgmse-bbed)
42
+
43
+ ## Installation
44
+
45
+ - Create a new virtual environment with Python 3.11 (we have not tested other Python versions, but they may work).
46
+ - Install the package dependencies via `pip install -r requirements.txt`.
47
+ - Let pip resolve the dependencies for you. If you encounter any issues, please check `requirements_version.txt` for the exact versions we used.
48
+ - If using W&B logging (default):
49
+ - Set up a [wandb.ai](https://wandb.ai/) account
50
+ - Log in via `wandb login` before running our code.
51
+ - If not using W&B logging:
52
+ - Pass the option `--nolog` to `train.py`.
53
+ - Your logs will be stored as local CSVLogger logs in `lightning_logs/`.
54
+
55
+ ## Pretrained checkpoints
56
+
57
+ - For the speech enhancement task, we offer pretrained checkpoints for models that have been trained on the VoiceBank-DEMAND and WSJ0-CHiME3 datasets, as described in our journal paper [2]. You can download them [here](https://drive.google.com/drive/folders/1CSnkhUSoiv3RG0xg7WEcVapyLuwDaLbe?usp=sharing).
58
+ - SGMSE+ trained on VoiceBank-DEMAND: `gdown 1_H3EXvhcYBhOZ9QNUcD5VZHc6ktrRbwQ`
59
+ - SGMSE+ trained on WSJ0-CHiME3: `gdown 16K4DUdpmLhDNC7pJhBBc08pkSIn_yMPi`
60
+ - For the dereverberation task, we offer a checkpoint trained on our WSJ0-REVERB dataset. You can download it [here](https://drive.google.com/drive/folders/1082_PSEgrqoVVrNsAkSIcpLF1AAtzGwV?usp=sharing).
61
+ - SGMSE+ trained on WSJ0-REVERB: `gdown 1eiOy0VjHh9V9ZUFTxu1Pq2w19izl9ejD`
62
+ - Note that this checkpoint works better with sampler settings `--N 50 --snr 0.33`.
63
+ - For 48 kHz models [3], we offer pretrained checkpoints for speech enhancement, trained on the EARS-WHAM dataset, and for dereverberation, trained on the EARS-Reverb dataset. You can download them [here](https://drive.google.com/drive/folders/1Tn6pVwjxUAy1DJ8167JCg3enuSi0hiw5?usp=sharing).
64
+ - SGMSE+ trained on EARS-WHAM: `gdown 1t_DLLk8iPH6nj8M5wGeOP3jFPaz3i7K5`
65
+ - SGMSE+ trained on EARS-Reverb: `gdown 1PunXuLbuyGkknQCn_y-RCV2dTZBhyE3V`
66
+
67
+ Usage:
68
+ - For resuming training, you can use the `--ckpt` option of `train.py`.
69
+ - For evaluating these checkpoints, use the `--ckpt` option of `enhancement.py` (see section **Evaluation** below).
70
+
71
+ ## Training
72
+
73
+ Training is done by executing `train.py`. A minimal running example with default settings (as in our paper [2]) can be run with
74
+
75
+ ```bash
76
+ python train.py --base_dir <your_base_dir>
77
+ ```
78
+
79
+ where `your_base_dir` should be a path to a folder containing subdirectories `train/` and `valid/` (optionally `test/` as well). Each subdirectory must itself have two subdirectories `clean/` and `noisy/`, with the same filenames present in both. We currently only support training with `.wav` files.
80
+
81
+ To see all available training options, run `python train.py --help`. Note that the available options for the SDE and the backbone network change depending on which SDE and backbone you use. These can be set through the `--sde` and `--backbone` options.
82
+
83
+ **Note:**
84
+ - Our journal preprint [2] uses `--backbone ncsnpp`.
85
+ - For the 48 kHz model [3], use `--backbone ncsnpp_48k --n_fft 1534 --hop_length 384 --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
86
+ - Our Interspeech paper [1] uses `--backbone dcunet`. You need to pass `--n_fft 512` to make it work.
87
+ - Also note that the default parameters for the spectrogram transformation in this repository are slightly different from the ones listed in the first (Interspeech) paper (`--spec_factor 0.15` rather than `--spec_factor 0.333`), but we've found the value in this repository to generally perform better for both models [1] and [2].
88
+
89
+ ## Evaluation
90
+
91
+ To evaluate on a test set, run
92
+ ```bash
93
+ python enhancement.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir> --ckpt <path_to_model_checkpoint>
94
+ ```
95
+
96
+ to generate the enhanced .wav files, and subsequently run
97
+
98
+ ```bash
99
+ python calc_metrics.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir>
100
+ ```
101
+
102
+ to calculate and output the instrumental metrics.
103
+
104
+ Both scripts should receive the same `--test_dir` and `--enhanced_dir` parameters. The `--cpkt` parameter of `enhancement.py` should be the path to a trained model checkpoint, as stored by the logger in `logs/`.
105
+
106
+ ## Citations / References
107
+
108
+ We kindly ask you to cite our papers in your publication when using any of our research or code:
109
+ ```bib
110
+ @inproceedings{welker22speech,
111
+ author={Simon Welker and Julius Richter and Timo Gerkmann},
112
+ title={Speech Enhancement with Score-Based Generative Models in the Complex {STFT} Domain},
113
+ year={2022},
114
+ booktitle={Proc. Interspeech 2022},
115
+ pages={2928--2932},
116
+ doi={10.21437/Interspeech.2022-10653}
117
+ }
118
+ ```
119
+ ```bib
120
+ @article{richter2023speech,
121
+ title={Speech Enhancement and Dereverberation with Diffusion-based Generative Models},
122
+ author={Richter, Julius and Welker, Simon and Lemercier, Jean-Marie and Lay, Bunlong and Gerkmann, Timo},
123
+ journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
124
+ volume={31},
125
+ pages={2351-2364},
126
+ year={2023},
127
+ doi={10.1109/TASLP.2023.3285241}
128
+ }
129
+ ```
130
+ ```bib
131
+ @inproceedings{richter2024ears,
132
+ title={{EARS}: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation},
133
+ author={Richter, Julius and Wu, Yi-Chiao and Krenn, Steven and Welker, Simon and Lay, Bunlong and Watanabe, Shinjii and Richard, Alexander and Gerkmann, Timo},
134
+ booktitle={ISCA Interspeech},
135
+ year={2024}
136
+ }
137
+ ```
138
+
139
+ >[1] Simon Welker, Julius Richter, Timo Gerkmann. "Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain", ISCA Interspeech, Incheon, Korea, Sep. 2022.
140
+ >
141
+ >[2] Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann. "Speech Enhancement and Dereverberation with Diffusion-Based Generative Models", IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.
142
+ >
143
+ >[3] Julius Richter, Yi-Chiao Wu, Steven Krenn, Simon Welker, Bunlong Lay, Shinji Watanabe, Alexander Richard, Timo Gerkmann. "EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation", ISCA Interspeech, Kos, Greece, 2024.
calc_metrics.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import join
2
+ from glob import glob
3
+ from argparse import ArgumentParser
4
+ from soundfile import read
5
+ from tqdm import tqdm
6
+ from pesq import pesq
7
+ import pandas as pd
8
+ import librosa
9
+
10
+ from pystoi import stoi
11
+
12
+ from sgmse.util.other import energy_ratios, mean_std
13
+
14
+
15
+ if __name__ == '__main__':
16
+ parser = ArgumentParser()
17
+ parser.add_argument("--clean_dir", type=str, required=True, help='Directory containing the clean data')
18
+ parser.add_argument("--noisy_dir", type=str, required=True, help='Directory containing the noisy data')
19
+ parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
20
+ args = parser.parse_args()
21
+
22
+ data = {"filename": [], "pesq": [], "estoi": [], "si_sdr": [], "si_sir": [], "si_sar": []}
23
+
24
+ # Evaluate standard metrics
25
+ noisy_files = []
26
+ noisy_files += sorted(glob(join(args.noisy_dir, '*.wav')))
27
+ noisy_files += sorted(glob(join(args.noisy_dir, '**', '*.wav')))
28
+ for noisy_file in tqdm(noisy_files):
29
+ filename = noisy_file.replace(args.noisy_dir, "")[1:]
30
+ if 'dB' in filename:
31
+ clean_filename = filename.split("_")[0] + ".wav"
32
+ else:
33
+ clean_filename = filename
34
+ x, sr_x = read(join(args.clean_dir, clean_filename))
35
+ y, sr_y = read(join(args.noisy_dir, filename))
36
+ x_hat, sr_x_hat = read(join(args.enhanced_dir, filename))
37
+ assert sr_x == sr_y == sr_x_hat
38
+ n = y - x
39
+ x_hat_16k = librosa.resample(x_hat, orig_sr=sr_x_hat, target_sr=16000) if sr_x_hat != 16000 else x_hat
40
+ x_16k = librosa.resample(x, orig_sr=sr_x, target_sr=16000) if sr_x != 16000 else x
41
+ data["filename"].append(filename)
42
+ data["pesq"].append(pesq(16000, x_16k, x_hat_16k, 'wb'))
43
+ data["estoi"].append(stoi(x, x_hat, sr_x, extended=True))
44
+ data["si_sdr"].append(energy_ratios(x_hat, x, n)[0])
45
+ data["si_sir"].append(energy_ratios(x_hat, x, n)[1])
46
+ data["si_sar"].append(energy_ratios(x_hat, x, n)[2])
47
+
48
+ # Save results as DataFrame
49
+ df = pd.DataFrame(data)
50
+
51
+ # Print results
52
+ print("PESQ: {:.2f} ± {:.2f}".format(*mean_std(df["pesq"].to_numpy())))
53
+ print("ESTOI: {:.2f} ± {:.2f}".format(*mean_std(df["estoi"].to_numpy())))
54
+ print("SI-SDR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sdr"].to_numpy())))
55
+ print("SI-SIR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sir"].to_numpy())))
56
+ print("SI-SAR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sar"].to_numpy())))
57
+
58
+ # Save average results to file
59
+ log = open(join(args.enhanced_dir, "_avg_results.txt"), "w")
60
+ log.write("PESQ: {:.2f} ± {:.2f}".format(*mean_std(df["pesq"].to_numpy())) + "\n")
61
+ log.write("ESTOI: {:.2f} ± {:.2f}".format(*mean_std(df["estoi"].to_numpy())) + "\n")
62
+ log.write("SI-SDR: {:.1f} ± {:.2f}".format(*mean_std(df["si_sdr"].to_numpy())) + "\n")
63
+ log.write("SI-SIR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sir"].to_numpy())) + "\n")
64
+ log.write("SI-SAR: {:.1f} ± {:.1f}".format(*mean_std(df["si_sar"].to_numpy())) + "\n")
65
+
66
+ # Save DataFrame as csv file
67
+ df.to_csv(join(args.enhanced_dir, "_results.csv"), index=False)
diffusion_process.png ADDED
enhancement.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import torch
3
+ from tqdm import tqdm
4
+ from os import makedirs
5
+ from soundfile import write
6
+ from torchaudio import load
7
+ from os.path import join, dirname
8
+ from argparse import ArgumentParser
9
+ from librosa import resample
10
+
11
+ # Set CUDA architecture list
12
+ from sgmse.util.other import set_torch_cuda_arch_list
13
+ set_torch_cuda_arch_list()
14
+
15
+ from sgmse.model import ScoreModel
16
+ from sgmse.util.other import pad_spec
17
+
18
+
19
+ if __name__ == '__main__':
20
+ parser = ArgumentParser()
21
+ parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the test data')
22
+ parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
23
+ parser.add_argument("--ckpt", type=str, help='Path to model checkpoint')
24
+ parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald", help="Corrector class for the PC sampler.")
25
+ parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
26
+ parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics")
27
+ parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
28
+ parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference")
29
+ args = parser.parse_args()
30
+
31
+ # Load score model
32
+ model = ScoreModel.load_from_checkpoint(args.ckpt, map_location=args.device)
33
+ model.eval()
34
+
35
+ # Get list of noisy files
36
+ noisy_files = []
37
+ noisy_files += sorted(glob.glob(join(args.test_dir, '*.wav')))
38
+ noisy_files += sorted(glob.glob(join(args.test_dir, '**', '*.wav')))
39
+
40
+ # Check if the model is trained on 48 kHz data
41
+ if model.backbone == 'ncsnpp_48k':
42
+ target_sr = 48000
43
+ pad_mode = "reflection"
44
+ else:
45
+ target_sr = 16000
46
+ pad_mode = "zero_pad"
47
+
48
+ # Enhance files
49
+ for noisy_file in tqdm(noisy_files):
50
+ filename = noisy_file.replace(args.test_dir, "")
51
+ filename = filename[1:] if filename.startswith("/") else filename
52
+
53
+ # Load wav
54
+ y, sr = load(noisy_file)
55
+
56
+ # Resample if necessary
57
+ if sr != target_sr:
58
+ y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
59
+
60
+ T_orig = y.size(1)
61
+
62
+ # Normalize
63
+ norm_factor = y.abs().max()
64
+ y = y / norm_factor
65
+
66
+ # Prepare DNN input
67
+ Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
68
+ Y = pad_spec(Y, mode=pad_mode)
69
+
70
+ # Reverse sampling
71
+ sampler = model.get_pc_sampler(
72
+ 'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N,
73
+ corrector_steps=args.corrector_steps, snr=args.snr)
74
+ sample, _ = sampler()
75
+
76
+ # Backward transform in time domain
77
+ x_hat = model.to_audio(sample.squeeze(), T_orig)
78
+
79
+ # Renormalize
80
+ x_hat = x_hat * norm_factor
81
+
82
+ # Write enhanced wav file
83
+ makedirs(dirname(join(args.enhanced_dir, filename)), exist_ok=True)
84
+ write(join(args.enhanced_dir, filename), x_hat.cpu().numpy(), target_sr)
logs/.keep ADDED
@@ -0,0 +1 @@
 
 
1
+
preprocessing/create_wsj0_chime3.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ from librosa import load
4
+ from librosa.core import resample
5
+ import argparse
6
+ from argparse import ArgumentParser
7
+ from pathlib import Path
8
+ import numpy as np
9
+ from soundfile import write
10
+ from tqdm import tqdm
11
+
12
+
13
+ # Python script for generating noisy mixtures for training
14
+ #
15
+ # Mix WSJ0 with CHiME3 noise with SNR sampled uniformly in [min_snr, max_snr]
16
+
17
+
18
+ min_snr = 0
19
+ max_snr = 20
20
+ sr = 16000
21
+
22
+
23
+ if __name__ == '__main__':
24
+ parser = ArgumentParser()
25
+ parser.add_argument("wsj0", type=str, help='path to WSJ0 directory')
26
+ parser.add_argument("chime3", type=str, help='path to CHiME3 directory')
27
+ parser.add_argument("target", type=str, help='target path for training files')
28
+ args = parser.parse_args()
29
+
30
+ # Clean speech for training
31
+ train_speech_files = sorted(glob(args.wsj0 + '**/si_tr_s/**/*.wav', recursive=True))
32
+ valid_speech_files = sorted(glob(args.wsj0 + '**/si_dt_05/**/*.wav', recursive=True))
33
+ test_speech_files = sorted(glob(args.wsj0 + '**/si_et_05/**/*.wav', recursive=True))
34
+
35
+ noise_files = glob(args.chime3 + '**/backgrounds/*.wav', recursive=True)
36
+ noise_files = [file for file in noise_files if (file[-7:-4] == "CH1")]
37
+
38
+ # Load CHiME3 noise files
39
+ noises = []
40
+ print('Loading CHiME3 noise files')
41
+ for file in noise_files:
42
+ noise = load(file, sr=None)[0]
43
+ noises.append(noise)
44
+
45
+ # Create target dir
46
+ train_clean_path = Path(os.path.join(args.target, 'train/clean'))
47
+ train_noisy_path = Path(os.path.join(args.target, 'train/noisy'))
48
+ valid_clean_path = Path(os.path.join(args.target, 'valid/clean'))
49
+ valid_noisy_path = Path(os.path.join(args.target, 'valid/noisy'))
50
+ test_clean_path = Path(os.path.join(args.target, 'test/clean'))
51
+ test_noisy_path = Path(os.path.join(args.target, 'test/noisy'))
52
+
53
+ train_clean_path.mkdir(parents=True, exist_ok=True)
54
+ train_noisy_path.mkdir(parents=True, exist_ok=True)
55
+ valid_clean_path.mkdir(parents=True, exist_ok=True)
56
+ valid_noisy_path.mkdir(parents=True, exist_ok=True)
57
+ test_clean_path.mkdir(parents=True, exist_ok=True)
58
+ test_noisy_path.mkdir(parents=True, exist_ok=True)
59
+
60
+ # Initialize seed for reproducability
61
+ np.random.seed(0)
62
+
63
+ # Create files for training
64
+ print('Create training files')
65
+ for i, speech_file in enumerate(tqdm(train_speech_files)):
66
+ s, _ = load(speech_file, sr=sr)
67
+
68
+ snr_dB = np.random.uniform(min_snr, max_snr)
69
+ noise_ind = np.random.randint(len(noises))
70
+ speech_power = 1/len(s)*np.sum(s**2)
71
+
72
+ n = noises[noise_ind]
73
+ start = np.random.randint(len(n)-len(s))
74
+ n = n[start:start+len(s)]
75
+
76
+ noise_power = 1/len(n)*np.sum(n**2)
77
+ noise_power_target = speech_power*np.power(10,-snr_dB/10)
78
+ k = noise_power_target / noise_power
79
+ n = n * np.sqrt(k)
80
+ x = s + n
81
+
82
+ file_name = speech_file.split('/')[-1]
83
+ write(os.path.join(train_clean_path, file_name), s, sr)
84
+ write(os.path.join(train_noisy_path, file_name), x, sr)
85
+
86
+ # Create files for validation
87
+ print('Create validation files')
88
+ for i, speech_file in enumerate(tqdm(valid_speech_files)):
89
+ s, _ = load(speech_file, sr=sr)
90
+
91
+ snr_dB = np.random.uniform(min_snr, max_snr)
92
+ noise_ind = np.random.randint(len(noises))
93
+ speech_power = 1/len(s)*np.sum(s**2)
94
+
95
+ n = noises[noise_ind]
96
+ start = np.random.randint(len(n)-len(s))
97
+ n = n[start:start+len(s)]
98
+
99
+ noise_power = 1/len(n)*np.sum(n**2)
100
+ noise_power_target = speech_power*np.power(10,-snr_dB/10)
101
+ k = noise_power_target / noise_power
102
+ n = n * np.sqrt(k)
103
+ x = s + n
104
+
105
+ file_name = speech_file.split('/')[-1]
106
+ write(os.path.join(valid_clean_path, file_name), s, sr)
107
+ write(os.path.join(valid_noisy_path, file_name), x, sr)
108
+
109
+ # Create files for test
110
+ print('Create test files')
111
+ for i, speech_file in enumerate(tqdm(test_speech_files)):
112
+ s, _ = load(speech_file, sr=sr)
113
+
114
+ snr_dB = np.random.uniform(min_snr, max_snr)
115
+ noise_ind = np.random.randint(len(noises))
116
+ speech_power = 1/len(s)*np.sum(s**2)
117
+
118
+ n = noises[noise_ind]
119
+ start = np.random.randint(len(n)-len(s))
120
+ n = n[start:start+len(s)]
121
+
122
+ noise_power = 1/len(n)*np.sum(n**2)
123
+ noise_power_target = speech_power*np.power(10,-snr_dB/10)
124
+ k = noise_power_target / noise_power
125
+ n = n * np.sqrt(k)
126
+ x = s + n
127
+
128
+ file_name = speech_file.split('/')[-1]
129
+ write(os.path.join(test_clean_path, file_name), s, sr)
130
+ write(os.path.join(test_noisy_path, file_name), x, sr)
preprocessing/create_wsj0_qut.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ from librosa import load
4
+ from librosa.core import resample
5
+ import argparse
6
+ from argparse import ArgumentParser
7
+ from pathlib import Path
8
+ import numpy as np
9
+ from soundfile import write
10
+ from tqdm import tqdm
11
+
12
+
13
+ # Python script for generating noisy mixtures for training
14
+ #
15
+ # Mix WSJ0 with QUT noise with SNR sampled uniformly in [min_snr, max_snr]
16
+
17
+
18
+ min_snr = 0
19
+ max_snr = 15
20
+ sr = 16000
21
+
22
+
23
+ if __name__ == '__main__':
24
+ parser = ArgumentParser()
25
+ parser.add_argument("wsj0", type=str, help='path to WSJ0 directory')
26
+ parser.add_argument("qut", type=str, help='path to QUT directory')
27
+ parser.add_argument("target", type=str, help='target path for training files')
28
+ args = parser.parse_args()
29
+
30
+ # Clean speech for training
31
+ train_speech_files = sorted(glob(args.wsj0 + '**/si_tr_s/**/*.wav', recursive=True))
32
+ valid_speech_files = sorted(glob(args.wsj0 + '**/si_dt_05/**/*.wav', recursive=True))
33
+ test_speech_files = sorted(glob(args.wsj0 + '**/si_et_05/**/*.wav', recursive=True))
34
+
35
+ # Load QUT noise files
36
+ print('Loading QUT noise files')
37
+ cafe, sr_QUT = load(glob(args.qut + '**/CAFE-CAFE-1.wav', recursive=True)[0], sr=None)
38
+ car, sr_QUT = load(glob(args.qut + '**/CAR-WINDOWNB-1.wav', recursive=True)[0], sr=None)
39
+ home, sr_QUT = load(glob(args.qut + '**/HOME-KITCHEN-1.wav', recursive=True)[0], sr=None)
40
+ street, sr_QUT = load(glob(args.qut + '**/STREET-CITY-1.wav', recursive=True)[0], sr=None)
41
+
42
+ print('Resampling QUT noise files to 16kHz')
43
+ cafe = resample(cafe, sr_QUT, sr)
44
+ car = resample(car, sr_QUT, sr)
45
+ home = resample(home, sr_QUT, sr)
46
+ street = resample(street, sr_QUT, sr)
47
+
48
+ # ToDo: resampling with ffmpeg bacause librosa is soooo slow
49
+ # cafe, fs_QUT = load(os.path.join(args.qut, 'CAFE-CAFE-1_16k.wav'), sr=None)
50
+ # car, fs_QUT = load(os.path.join(args.qut, 'CAR-WINDOWNB-1_16k.wav'), sr=None)
51
+ # home, fs_QUT = load(os.path.join(args.qut, 'HOME-KITCHEN-1_16k.wav'), sr=None)
52
+ # street, fs_QUT = load(os.path.join(args.qut, 'STREET-CITY-1_16k.wav'), sr=None)
53
+
54
+ # Remove sweeps in the first and last 2 min in car noise file
55
+ car = car[120*sr:-120*sr]
56
+
57
+ # Create target dir
58
+ train_clean_path = Path(os.path.join(args.target, 'train/clean'))
59
+ train_noisy_path = Path(os.path.join(args.target, 'train/noisy'))
60
+ valid_clean_path = Path(os.path.join(args.target, 'valid/clean'))
61
+ valid_noisy_path = Path(os.path.join(args.target, 'valid/noisy'))
62
+ test_clean_path = Path(os.path.join(args.target, 'test/clean'))
63
+ test_noisy_path = Path(os.path.join(args.target, 'test/noisy'))
64
+
65
+ train_clean_path.mkdir(parents=True, exist_ok=True)
66
+ train_noisy_path.mkdir(parents=True, exist_ok=True)
67
+ valid_clean_path.mkdir(parents=True, exist_ok=True)
68
+ valid_noisy_path.mkdir(parents=True, exist_ok=True)
69
+ test_clean_path.mkdir(parents=True, exist_ok=True)
70
+ test_noisy_path.mkdir(parents=True, exist_ok=True)
71
+
72
+ # Initialize seed for reproducability
73
+ np.random.seed(0)
74
+
75
+ # Create files for training
76
+ print('Create training files')
77
+ for i, speech_file in enumerate(tqdm(train_speech_files)):
78
+ s, _ = load(speech_file, sr=sr)
79
+
80
+ snr_dB = np.random.uniform(min_snr, max_snr)
81
+ noise_type = np.random.randint(4)
82
+ speech_power = 1/len(s)*np.sum(s**2)
83
+
84
+ if noise_type == 0:
85
+ start = np.random.randint(len(cafe)-len(s))
86
+ n = cafe[start:start+len(s)]
87
+ elif noise_type == 1:
88
+ start = np.random.randint(len(home)-len(s))
89
+ n = home[start:start+len(s)]
90
+ elif noise_type == 2:
91
+ start = np.random.randint(len(street)-len(s))
92
+ n = street[start:start+len(s)]
93
+ elif noise_type == 3:
94
+ start = np.random.randint(len(car)-len(s))
95
+ n = car[start:start+len(s)]
96
+ else:
97
+ raise ValueError('Unexpected noise type index')
98
+ noise_power = 1/len(n)*np.sum(n**2)
99
+ noise_power_target = speech_power*np.power(10,-snr_dB/10)
100
+ k = noise_power_target / noise_power
101
+ n = n * np.sqrt(k)
102
+ x = s + n
103
+
104
+ file_name = speech_file.split('/')[-1]
105
+ write(os.path.join(train_clean_path, file_name), s, sr)
106
+ write(os.path.join(train_noisy_path, file_name), x, sr)
107
+
108
+ # Create files for validation
109
+ print('Create validation files')
110
+ for i, speech_file in enumerate(tqdm(valid_speech_files)):
111
+ s, _ = load(speech_file, sr=sr)
112
+
113
+ snr_dB = np.random.uniform(min_snr, max_snr)
114
+ noise_type = np.random.randint(4)
115
+ speech_power = 1/len(s)*np.sum(s**2)
116
+
117
+ if noise_type == 0:
118
+ start = np.random.randint(len(cafe)-len(s))
119
+ n = cafe[start:start+len(s)]
120
+ elif noise_type == 1:
121
+ start = np.random.randint(len(home)-len(s))
122
+ n = home[start:start+len(s)]
123
+ elif noise_type == 2:
124
+ start = np.random.randint(len(street)-len(s))
125
+ n = street[start:start+len(s)]
126
+ elif noise_type == 3:
127
+ start = np.random.randint(len(car)-len(s))
128
+ n = car[start:start+len(s)]
129
+ else:
130
+ raise ValueError('Unexpected noise type index')
131
+ noise_power = 1/len(n)*np.sum(n**2)
132
+ noise_power_target = speech_power*np.power(10,-snr_dB/10)
133
+ k = noise_power_target / noise_power
134
+ n = n * np.sqrt(k)
135
+ x = s + n
136
+
137
+ file_name = speech_file.split('/')[-1]
138
+ write(os.path.join(valid_clean_path, file_name), s, sr)
139
+ write(os.path.join(valid_noisy_path, file_name), x, sr)
140
+
141
+ # Create files for test
142
+ print('Create test files')
143
+ for i, speech_file in enumerate(tqdm(test_speech_files)):
144
+ s, _ = load(speech_file, sr=sr)
145
+
146
+ snr_dB = np.random.uniform(min_snr, max_snr)
147
+ noise_type = np.random.randint(4)
148
+ speech_power = 1/len(s)*np.sum(s**2)
149
+
150
+ if noise_type == 0:
151
+ start = np.random.randint(len(cafe)-len(s))
152
+ n = cafe[start:start+len(s)]
153
+ elif noise_type == 1:
154
+ start = np.random.randint(len(home)-len(s))
155
+ n = home[start:start+len(s)]
156
+ elif noise_type == 2:
157
+ start = np.random.randint(len(street)-len(s))
158
+ n = street[start:start+len(s)]
159
+ elif noise_type == 3:
160
+ start = np.random.randint(len(car)-len(s))
161
+ n = car[start:start+len(s)]
162
+ else:
163
+ raise ValueError('Unexpected noise type index')
164
+ noise_power = 1/len(n)*np.sum(n**2)
165
+ noise_power_target = speech_power*np.power(10,-snr_dB/10)
166
+ k = noise_power_target / noise_power
167
+ n = n * np.sqrt(k)
168
+ x = s + n
169
+
170
+ file_name = speech_file.split('/')[-1]
171
+ write(os.path.join(test_clean_path, file_name), s, sr)
172
+ write(os.path.join(test_noisy_path, file_name), x, sr)
preprocessing/create_wsj0_reverb.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import argparse
4
+ import numpy as np
5
+ import soundfile as sf
6
+ import pyroomacoustics as pra
7
+ from glob import glob
8
+ from tqdm import tqdm
9
+
10
+
11
+ SEED = 100
12
+ np.random.seed(SEED)
13
+
14
+ T60_RANGE = [0.4, 1.0]
15
+ SNR_RANGE = [0, 20]
16
+ DIM_RANGE = [5, 15, 5, 15, 2, 6]
17
+ MIN_DISTANCE_TO_WALL = 1
18
+ MIC_ARRAY_RADIUS = 0.16
19
+ TARGET_T60_SHAPE = {"CI": 0.08, "HA": 0.2}
20
+ TARGET_T60_SHAPE = {"CI": 0.10, "HA": 0.2}
21
+ TARGETS_CROP = {"CI": 16e-3, "HA": 40e-3}
22
+ NB_SAMPLES_PER_ROOM = 1
23
+ CHANNELS = 1
24
+
25
+ if __name__ == '__main__':
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument('--wsj0_dir', type=str, required=True, help='Path to the WSJ0 directory which should contain subdirectories "si_dt_05", "si_tr_s" and "si_et_05".')
28
+ parser.add_argument('--target_dir', type=str, required=True, help='Path to the target directory for saving WSJ0-REVERB.')
29
+ args = parser.parse_args()
30
+
31
+ def obtain_clean_file(speech_list, i_sample, sample_rate=16000):
32
+ speech, speech_sr = sf.read(speech_list[i_sample])
33
+ speech_basename = os.path.basename(speech_list[i_sample])
34
+ assert speech_sr == sample_rate, f"wrong speech sampling rate here: expected {sample_rate} got {speech_sr}"
35
+ return speech.squeeze(), speech_sr, speech_basename[: -4]
36
+
37
+ splits = ['valid', 'train', 'test']
38
+ dic_split = {"valid": "si_dt_05", "train": "si_tr_s", "test": "si_et_05"}
39
+ speech_lists = {split:sorted(glob(f"{os.path.join(args.wsj0_dir, dic_split[split])}/**/*.wav")) for split in splits}
40
+ sample_rate = 16000
41
+ output_dir = args.target_dir
42
+
43
+ if os.path.exists(output_dir):
44
+ shutil.rmtree(output_dir)
45
+
46
+ for i_split, split in enumerate(splits):
47
+ print("Processing split n° {}: {}...".format(i_split+1, split))
48
+
49
+ reverberant_output_dir = os.path.join(output_dir, "audio", split, "reverb")
50
+ dry_output_dir = os.path.join(output_dir, "audio", split, "anechoic")
51
+ noisy_reverberant_output_dir = os.path.join(output_dir, "audio", split, "noisy_reverb")
52
+ if split == "test":
53
+ unauralized_output_dir = os.path.join(output_dir, "audio", split, "unauralized")
54
+
55
+ os.makedirs(reverberant_output_dir, exist_ok=True)
56
+ os.makedirs(dry_output_dir, exist_ok=True)
57
+ if split == "test":
58
+ os.makedirs(unauralized_output_dir, exist_ok=True)
59
+
60
+ speech_list = speech_lists[split]
61
+ speech_dir = None
62
+ real_nb_samples = len(speech_list)
63
+
64
+ for i_sample in tqdm(range(real_nb_samples)):
65
+ if not i_sample % NB_SAMPLES_PER_ROOM: #Generate new room
66
+ t60 = np.random.uniform(T60_RANGE[0], T60_RANGE[1]) #Draw T60
67
+ room_dim = np.array([ np.random.uniform(DIM_RANGE[2*n], DIM_RANGE[2*n+1]) for n in range(3) ]) #Draw Dimensions
68
+ center_mic_position = np.array([ np.random.uniform(MIN_DISTANCE_TO_WALL, room_dim[n] - MIN_DISTANCE_TO_WALL) for n in range(3) ]) #draw source position
69
+ source_position = np.array([ np.random.uniform(MIN_DISTANCE_TO_WALL, room_dim[n] - MIN_DISTANCE_TO_WALL) for n in range(3) ]) #draw source position
70
+ mic_array_2d = pra.beamforming.circular_2D_array(center_mic_position[: -1], CHANNELS, phi0=0, radius=MIC_ARRAY_RADIUS) # Compute microphone array
71
+ mic_array = np.pad(mic_array_2d, ((0, 1), (0, 0)), mode="constant", constant_values=center_mic_position[-1])
72
+
73
+ ### Reverberant Room
74
+ e_absorption, max_order = pra.inverse_sabine(t60, room_dim) #Compute absorption coeff
75
+ reverberant_room = pra.ShoeBox(
76
+ room_dim, fs=16000, materials=pra.Material(e_absorption), max_order=min(3, max_order)
77
+ ) # Create room
78
+ reverberant_room.set_ray_tracing()
79
+ reverberant_room.add_microphone_array(mic_array) # Add microphone array
80
+
81
+ # Pick unauralized files
82
+ speech, speech_sr, speech_basename = obtain_clean_file(speech_list, i_sample, sample_rate=sample_rate)
83
+
84
+ # Generate reverberant room
85
+ reverberant_room.add_source(source_position, signal=speech)
86
+ reverberant_room.compute_rir()
87
+ reverberant_room.simulate()
88
+ t60_real = np.mean(reverberant_room.measure_rt60()).squeeze()
89
+ reverberant = np.stack(reverberant_room.mic_array.signals).swapaxes(0, 1)
90
+
91
+ e_absorption_dry = 0.99 #For Neural Networks OK but clearly not for WPE
92
+ dry_room = pra.ShoeBox(
93
+ room_dim, fs=16000, materials=pra.Material(e_absorption_dry), max_order=0
94
+ ) # Create room
95
+ dry_room.add_microphone_array(mic_array) # Add microphone array
96
+
97
+ # Generate dry room
98
+ dry_room.add_source(source_position, signal=speech)
99
+ dry_room.compute_rir()
100
+ dry_room.simulate()
101
+ t60_real_dry = np.mean(dry_room.measure_rt60()).squeeze()
102
+ rir_dry = dry_room.rir
103
+ dry = np.stack(dry_room.mic_array.signals).swapaxes(0, 1)
104
+ dry = np.pad(dry, ((0, int(.5*sample_rate)), (0, 0)), mode="constant", constant_values=0) #Add 1 second of silence after dry (because very dry) so that the reverb is not cut, and all samples have same length
105
+
106
+ min_len_sample = min(reverberant.shape[0], dry.shape[0])
107
+ dry = dry[: min_len_sample]
108
+ reverberant = reverberant[: min_len_sample]
109
+ output_scaling = np.max(reverberant) / .9
110
+
111
+ drr = 10*np.log10( np.mean(dry**2) / (np.mean(reverberant**2) + 1e-8) + 1e-8 )
112
+ output_filename = f"{speech_basename}_{i_sample//NB_SAMPLES_PER_ROOM}_{t60_real:.2f}_{drr:.1f}.wav"
113
+
114
+ sf.write(os.path.join(dry_output_dir, output_filename), 1/output_scaling*dry, samplerate=sample_rate)
115
+ sf.write(os.path.join(reverberant_output_dir, output_filename), 1/output_scaling*reverberant, samplerate=sample_rate)
116
+
117
+ if split == "test":
118
+ sf.write(os.path.join(unauralized_output_dir, output_filename), speech, samplerate=sample_rate)
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gdown
2
+ h5py
3
+ ipympl
4
+ librosa
5
+ ninja
6
+ numpy<2.0
7
+ pandas
8
+ pesq
9
+ pillow
10
+ protobuf
11
+ pyarrow
12
+ pyroomacoustics
13
+ pystoi
14
+ pytorch-lightning
15
+ scipy
16
+ sdeint
17
+ setuptools
18
+ seaborn
19
+ torch
20
+ torch-ema
21
+ torchaudio
22
+ torchvision
23
+ torchinfo
24
+ torchsde
25
+ tqdm
26
+ wandb
requirements_version.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h5py==3.10.0
2
+ ipympl==0.9.3
3
+ librosa==0.10.1
4
+ ninja==1.11.1.1
5
+ numpy==1.24.4
6
+ pandas==2.0.3
7
+ pesq==0.0.4
8
+ pillow==10.2.0
9
+ protobuf==4.25.2
10
+ pyarrow==15.0.0
11
+ pyroomacoustics==0.7.3
12
+ pystoi==0.4.1
13
+ pytorch-lightning==2.1.4
14
+ scipy==1.10.1
15
+ sdeint==0.3.0
16
+ setuptools==44.0.0
17
+ seaborn==0.13.2
18
+ torch==2.2.0
19
+ torch-ema==0.3
20
+ torchaudio==2.2.0
21
+ torchvision==0.17.0
22
+ torchinfo==1.8.0
23
+ torchsde==0.2.6
24
+ tqdm==4.66.1
25
+ wandb==0.16.2
sgmse/backbones/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .shared import BackboneRegistry
2
+ from .ncsnpp import NCSNpp
3
+ from .ncsnpp_48k import NCSNpp_48k
4
+ from .dcunet import DCUNet
5
+
6
+ __all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_48k', 'DCUNet']
sgmse/backbones/dcunet.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import numpy as np
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
+ from .shared import BackboneRegistry, ComplexConv2d, ComplexConvTranspose2d, ComplexLinear, \
9
+ DiffusionStepEmbedding, GaussianFourierProjection, FeatureMapDense, torch_complex_from_reim
10
+
11
+
12
+ def get_activation(name):
13
+ if name == "silu":
14
+ return nn.SiLU
15
+ elif name == "relu":
16
+ return nn.ReLU
17
+ elif name == "leaky_relu":
18
+ return nn.LeakyReLU
19
+ else:
20
+ raise NotImplementedError(f"Unknown activation: {name}")
21
+
22
+
23
+ class BatchNorm(_BatchNorm):
24
+ def _check_input_dim(self, input):
25
+ if input.dim() < 2 or input.dim() > 4:
26
+ raise ValueError("expected 4D or 3D input (got {}D input)".format(input.dim()))
27
+
28
+
29
+ class OnReIm(nn.Module):
30
+ def __init__(self, module_cls, *args, **kwargs):
31
+ super().__init__()
32
+ self.re_module = module_cls(*args, **kwargs)
33
+ self.im_module = module_cls(*args, **kwargs)
34
+
35
+ def forward(self, x):
36
+ return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
37
+
38
+
39
+ # Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers!
40
+
41
+ def unet_decoder_args(encoders, *, skip_connections):
42
+ """Get list of decoder arguments for upsampling (right) side of a symmetric u-net,
43
+ given the arguments used to construct the encoder.
44
+ Args:
45
+ encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding)):
46
+ List of arguments used to construct the encoders
47
+ skip_connections (bool): Whether to include skip connections in the
48
+ calculation of decoder input channels.
49
+ Return:
50
+ tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding):
51
+ Arguments to be used to construct decoders
52
+ """
53
+ decoder_args = []
54
+ for enc_in_chan, enc_out_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation in reversed(encoders):
55
+ if skip_connections and decoder_args:
56
+ skip_in_chan = enc_out_chan
57
+ else:
58
+ skip_in_chan = 0
59
+ decoder_args.append(
60
+ (enc_out_chan + skip_in_chan, enc_in_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation)
61
+ )
62
+ return tuple(decoder_args)
63
+
64
+
65
+ def make_unet_encoder_decoder_args(encoder_args, decoder_args):
66
+ encoder_args = tuple(
67
+ (
68
+ in_chan,
69
+ out_chan,
70
+ tuple(kernel_size),
71
+ tuple(stride),
72
+ tuple([n // 2 for n in kernel_size]) if padding == "auto" else tuple(padding),
73
+ tuple(dilation)
74
+ )
75
+ for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args
76
+ )
77
+
78
+ if decoder_args == "auto":
79
+ decoder_args = unet_decoder_args(
80
+ encoder_args,
81
+ skip_connections=True,
82
+ )
83
+ else:
84
+ decoder_args = tuple(
85
+ (
86
+ in_chan,
87
+ out_chan,
88
+ tuple(kernel_size),
89
+ tuple(stride),
90
+ tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding,
91
+ tuple(dilation),
92
+ output_padding,
93
+ )
94
+ for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args
95
+ )
96
+
97
+ return encoder_args, decoder_args
98
+
99
+
100
+ DCUNET_ARCHITECTURES = {
101
+ "DCUNet-10": make_unet_encoder_decoder_args(
102
+ # Encoders:
103
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
104
+ (
105
+ (1, 32, (7, 5), (2, 2), "auto", (1,1)),
106
+ (32, 64, (7, 5), (2, 2), "auto", (1,1)),
107
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
108
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
109
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
110
+ ),
111
+ # Decoders: automatic inverse
112
+ "auto",
113
+ ),
114
+ "DCUNet-16": make_unet_encoder_decoder_args(
115
+ # Encoders:
116
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
117
+ (
118
+ (1, 32, (7, 5), (2, 2), "auto", (1,1)),
119
+ (32, 32, (7, 5), (2, 1), "auto", (1,1)),
120
+ (32, 64, (7, 5), (2, 2), "auto", (1,1)),
121
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
122
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
123
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
124
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
125
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
126
+ ),
127
+ # Decoders: automatic inverse
128
+ "auto",
129
+ ),
130
+ "DCUNet-20": make_unet_encoder_decoder_args(
131
+ # Encoders:
132
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
133
+ (
134
+ (1, 32, (7, 1), (1, 1), "auto", (1,1)),
135
+ (32, 32, (1, 7), (1, 1), "auto", (1,1)),
136
+ (32, 64, (7, 5), (2, 2), "auto", (1,1)),
137
+ (64, 64, (7, 5), (2, 1), "auto", (1,1)),
138
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
139
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
140
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
141
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
142
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
143
+ (64, 90, (5, 3), (2, 1), "auto", (1,1)),
144
+ ),
145
+ # Decoders: automatic inverse
146
+ "auto",
147
+ ),
148
+ "DilDCUNet-v2": make_unet_encoder_decoder_args( # architecture used in SGMSE / Interspeech paper
149
+ # Encoders:
150
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
151
+ (
152
+ (1, 32, (4, 4), (1, 1), "auto", (1, 1)),
153
+ (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
154
+ (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
155
+ (32, 64, (4, 4), (2, 1), "auto", (2, 1)),
156
+ (64, 128, (4, 4), (2, 2), "auto", (4, 1)),
157
+ (128, 256, (4, 4), (2, 2), "auto", (8, 1)),
158
+ ),
159
+ # Decoders: automatic inverse
160
+ "auto",
161
+ ),
162
+ }
163
+
164
+
165
+ @BackboneRegistry.register("dcunet")
166
+ class DCUNet(nn.Module):
167
+ @staticmethod
168
+ def add_argparse_args(parser):
169
+ parser.add_argument("--dcunet-architecture", type=str, default="DilDCUNet-v2", choices=DCUNET_ARCHITECTURES.keys(), help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.")
170
+ parser.add_argument("--dcunet-time-embedding", type=str, choices=("gfp", "ds", "none"), default="gfp", help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.")
171
+ parser.add_argument("--dcunet-temb-layers-global", type=int, default=1, help="Number of global linear+activation layers for the time embedding. 1 by default.")
172
+ parser.add_argument("--dcunet-temb-layers-local", type=int, default=1, help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.")
173
+ parser.add_argument("--dcunet-temb-activation", type=str, default="silu", help="The (complex) activation to use between all (global&local) time embedding layers.")
174
+ parser.add_argument("--dcunet-time-embedding-complex", action="store_true", help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.")
175
+ parser.add_argument("--dcunet-fix-length", type=str, default="pad", choices=("pad", "trim", "none"), help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.")
176
+ parser.add_argument("--dcunet-mask-bound", type=str, choices=("tanh", "sigmoid", "none"), default="none", help="DCUNet output bounding strategy. 'none' by default.")
177
+ parser.add_argument("--dcunet-norm-type", type=str, choices=("bN", "CbN"), default="bN", help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.")
178
+ parser.add_argument("--dcunet-activation", type=str, choices=("leaky_relu", "relu", "silu"), default="leaky_relu", help="The activation to use within each encoder and decoder layer. 'leaky_relu' by default.")
179
+ return parser
180
+
181
+ def __init__(
182
+ self,
183
+ dcunet_architecture: str = "DilDCUNet-v2",
184
+ dcunet_time_embedding: str = "gfp",
185
+ dcunet_temb_layers_global: int = 2,
186
+ dcunet_temb_layers_local: int = 1,
187
+ dcunet_temb_activation: str = "silu",
188
+ dcunet_time_embedding_complex: bool = False,
189
+ dcunet_fix_length: str = "pad",
190
+ dcunet_mask_bound: str = "none",
191
+ dcunet_norm_type: str = "bN",
192
+ dcunet_activation: str = "relu",
193
+ embed_dim: int = 128,
194
+ **kwargs
195
+ ):
196
+ super().__init__()
197
+
198
+ self.architecture = dcunet_architecture
199
+ self.fix_length_mode = (dcunet_fix_length if dcunet_fix_length != "none" else None)
200
+ self.norm_type = dcunet_norm_type
201
+ self.activation = dcunet_activation
202
+ self.input_channels = 2 # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN
203
+ self.time_embedding = (dcunet_time_embedding if dcunet_time_embedding != "none" else None)
204
+ self.time_embedding_complex = dcunet_time_embedding_complex
205
+ self.temb_layers_global = dcunet_temb_layers_global
206
+ self.temb_layers_local = dcunet_temb_layers_local
207
+ self.temb_activation = dcunet_temb_activation
208
+ conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture]
209
+
210
+ # Replace `input_channels` in encoders config
211
+ _replaced_input_channels, *rest = conf_encoders[0]
212
+ encoders = ((self.input_channels, *rest), *conf_encoders[1:])
213
+ decoders = conf_decoders
214
+ self.encoders_stride_product = np.prod(
215
+ [enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0
216
+ )
217
+
218
+ # Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation)
219
+ encoder_decoder_kwargs = dict(
220
+ norm_type=self.norm_type, activation=self.activation,
221
+ temb_layers=self.temb_layers_local, temb_activation=self.temb_activation)
222
+
223
+ # Instantiate (global) time embedding layer
224
+ embed_ops = []
225
+ if self.time_embedding is not None:
226
+ complex_valued = self.time_embedding_complex
227
+ if self.time_embedding == "gfp":
228
+ embed_ops += [GaussianFourierProjection(embed_dim=embed_dim, complex_valued=complex_valued)]
229
+ encoder_decoder_kwargs["embed_dim"] = embed_dim
230
+ elif self.time_embedding == "ds":
231
+ embed_ops += [DiffusionStepEmbedding(embed_dim=embed_dim, complex_valued=complex_valued)]
232
+ encoder_decoder_kwargs["embed_dim"] = embed_dim
233
+
234
+ if self.time_embedding_complex:
235
+ assert self.time_embedding in ("gfp", "ds"), "Complex timestep embedding only available for gfp and ds"
236
+ encoder_decoder_kwargs["complex_time_embedding"] = True
237
+ for _ in range(self.temb_layers_global):
238
+ embed_ops += [
239
+ ComplexLinear(embed_dim, embed_dim, complex_valued=True),
240
+ OnReIm(get_activation(dcunet_temb_activation))
241
+ ]
242
+ self.embed = nn.Sequential(*embed_ops)
243
+
244
+ ### Instantiate DCUNet layers ###
245
+ output_layer = ComplexConvTranspose2d(*decoders[-1])
246
+ encoders = [DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs) for args in encoders]
247
+ decoders = [DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs) for args in decoders[:-1]]
248
+
249
+ self.mask_bound = (dcunet_mask_bound if dcunet_mask_bound != "none" else None)
250
+ if self.mask_bound is not None:
251
+ raise NotImplementedError("sorry, mask bounding not implemented at the moment")
252
+ # TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument
253
+ #operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound))
254
+ #output_layer = nn.Sequential(*[x for x in operations if x is not None])
255
+
256
+ assert len(encoders) == len(decoders) + 1
257
+ self.encoders = nn.ModuleList(encoders)
258
+ self.decoders = nn.ModuleList(decoders)
259
+ self.output_layer = output_layer or nn.Identity()
260
+
261
+ def forward(self, spec, t) -> Tensor:
262
+ """
263
+ Input shape is expected to be $(batch, nfreqs, time)$, with $nfreqs - 1$ divisible
264
+ by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the encoders,
265
+ and $time - 1$ is divisible by $t_0 * t_1 * ... * t_N$ where $t_N$ are the time
266
+ strides of the encoders.
267
+ Args:
268
+ spec (Tensor): complex spectrogram tensor. 1D, 2D or 3D tensor, time last.
269
+ Returns:
270
+ Tensor, of shape (batch, time) or (time).
271
+ """
272
+ # TF-rep shape: (batch, self.input_channels, n_fft, frames)
273
+ # Estimate mask from time-frequency representation.
274
+ x_in = self.fix_input_dims(spec)
275
+ x = x_in
276
+ t_embed = self.embed(t+0j) if self.time_embedding is not None else None
277
+
278
+ enc_outs = []
279
+ for idx, enc in enumerate(self.encoders):
280
+ x = enc(x, t_embed)
281
+ # UNet skip connection
282
+ enc_outs.append(x)
283
+ for (enc_out, dec) in zip(reversed(enc_outs[:-1]), self.decoders):
284
+ x = dec(x, t_embed, output_size=enc_out.shape)
285
+ x = torch.cat([x, enc_out], dim=1)
286
+
287
+ output = self.output_layer(x, output_size=x_in.shape)
288
+ # output shape: (batch, 1, n_fft, frames)
289
+ output = self.fix_output_dims(output, spec)
290
+ return output
291
+
292
+ def fix_input_dims(self, x):
293
+ return _fix_dcu_input_dims(
294
+ self.fix_length_mode, x, torch.from_numpy(self.encoders_stride_product)
295
+ )
296
+
297
+ def fix_output_dims(self, out, x):
298
+ return _fix_dcu_output_dims(self.fix_length_mode, out, x)
299
+
300
+
301
+ def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product):
302
+ """Pad or trim `x` to a length compatible with DCUNet."""
303
+ freq_prod = int(encoders_stride_product[0])
304
+ time_prod = int(encoders_stride_product[1])
305
+ if (x.shape[2] - 1) % freq_prod:
306
+ raise TypeError(
307
+ f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by "
308
+ f"{freq_prod}, got {x.shape} instead"
309
+ )
310
+ time_remainder = (x.shape[3] - 1) % time_prod
311
+ if time_remainder:
312
+ if fix_length_mode is None:
313
+ raise TypeError(
314
+ f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by "
315
+ f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument "
316
+ f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically."
317
+ )
318
+ elif fix_length_mode == "pad":
319
+ pad_shape = [0, time_prod - time_remainder]
320
+ x = nn.functional.pad(x, pad_shape, mode="constant")
321
+ elif fix_length_mode == "trim":
322
+ pad_shape = [0, -time_remainder]
323
+ x = nn.functional.pad(x, pad_shape, mode="constant")
324
+ else:
325
+ raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'")
326
+ return x
327
+
328
+
329
+ def _fix_dcu_output_dims(fix_length_mode, out, x):
330
+ """Fix shape of `out` to the original shape of `x` by padding/cropping."""
331
+ inp_len = x.shape[-1]
332
+ output_len = out.shape[-1]
333
+ return nn.functional.pad(out, [0, inp_len - output_len])
334
+
335
+
336
+ def _get_norm(norm_type):
337
+ if norm_type == "CbN":
338
+ return ComplexBatchNorm
339
+ elif norm_type == "bN":
340
+ return partial(OnReIm, BatchNorm)
341
+ else:
342
+ raise NotImplementedError(f"Unknown norm type: {norm_type}")
343
+
344
+
345
+ class DCUNetComplexEncoderBlock(nn.Module):
346
+ def __init__(
347
+ self,
348
+ in_chan,
349
+ out_chan,
350
+ kernel_size,
351
+ stride,
352
+ padding,
353
+ dilation,
354
+ norm_type="bN",
355
+ activation="leaky_relu",
356
+ embed_dim=None,
357
+ complex_time_embedding=False,
358
+ temb_layers=1,
359
+ temb_activation="silu"
360
+ ):
361
+ super().__init__()
362
+
363
+ self.in_chan = in_chan
364
+ self.out_chan = out_chan
365
+ self.kernel_size = kernel_size
366
+ self.stride = stride
367
+ self.padding = padding
368
+ self.dilation = dilation
369
+ self.temb_layers = temb_layers
370
+ self.temb_activation = temb_activation
371
+ self.complex_time_embedding = complex_time_embedding
372
+
373
+ self.conv = ComplexConv2d(
374
+ in_chan, out_chan, kernel_size, stride, padding, bias=norm_type is None, dilation=dilation
375
+ )
376
+ self.norm = _get_norm(norm_type)(out_chan)
377
+ self.activation = OnReIm(get_activation(activation))
378
+ self.embed_dim = embed_dim
379
+ if self.embed_dim is not None:
380
+ ops = []
381
+ for _ in range(max(0, self.temb_layers - 1)):
382
+ ops += [
383
+ ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
384
+ OnReIm(get_activation(self.temb_activation))
385
+ ]
386
+ ops += [
387
+ FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
388
+ OnReIm(get_activation(self.temb_activation))
389
+ ]
390
+ self.embed_layer = nn.Sequential(*ops)
391
+
392
+ def forward(self, x, t_embed):
393
+ y = self.conv(x)
394
+ if self.embed_dim is not None:
395
+ y = y + self.embed_layer(t_embed)
396
+ return self.activation(self.norm(y))
397
+
398
+
399
+ class DCUNetComplexDecoderBlock(nn.Module):
400
+ def __init__(
401
+ self,
402
+ in_chan,
403
+ out_chan,
404
+ kernel_size,
405
+ stride,
406
+ padding,
407
+ dilation,
408
+ output_padding=(0, 0),
409
+ norm_type="bN",
410
+ activation="leaky_relu",
411
+ embed_dim=None,
412
+ temb_layers=1,
413
+ temb_activation='swish',
414
+ complex_time_embedding=False,
415
+ ):
416
+ super().__init__()
417
+
418
+ self.in_chan = in_chan
419
+ self.out_chan = out_chan
420
+ self.kernel_size = kernel_size
421
+ self.stride = stride
422
+ self.padding = padding
423
+ self.dilation = dilation
424
+ self.output_padding = output_padding
425
+ self.complex_time_embedding = complex_time_embedding
426
+ self.temb_layers = temb_layers
427
+ self.temb_activation = temb_activation
428
+
429
+ self.deconv = ComplexConvTranspose2d(
430
+ in_chan, out_chan, kernel_size, stride, padding, output_padding, dilation=dilation, bias=norm_type is None
431
+ )
432
+ self.norm = _get_norm(norm_type)(out_chan)
433
+ self.activation = OnReIm(get_activation(activation))
434
+ self.embed_dim = embed_dim
435
+ if self.embed_dim is not None:
436
+ ops = []
437
+ for _ in range(max(0, self.temb_layers - 1)):
438
+ ops += [
439
+ ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
440
+ OnReIm(get_activation(self.temb_activation))
441
+ ]
442
+ ops += [
443
+ FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
444
+ OnReIm(get_activation(self.temb_activation))
445
+ ]
446
+ self.embed_layer = nn.Sequential(*ops)
447
+
448
+ def forward(self, x, t_embed, output_size=None):
449
+ y = self.deconv(x, output_size=output_size)
450
+ if self.embed_dim is not None:
451
+ y = y + self.embed_layer(t_embed)
452
+ return self.activation(self.norm(y))
453
+
454
+
455
+ # From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py
456
+ class ComplexBatchNorm(torch.nn.Module):
457
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=False):
458
+ super(ComplexBatchNorm, self).__init__()
459
+ self.num_features = num_features
460
+ self.eps = eps
461
+ self.momentum = momentum
462
+ self.affine = affine
463
+ self.track_running_stats = track_running_stats
464
+ if self.affine:
465
+ self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
466
+ self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
467
+ self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
468
+ self.Br = torch.nn.Parameter(torch.Tensor(num_features))
469
+ self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
470
+ else:
471
+ self.register_parameter('Wrr', None)
472
+ self.register_parameter('Wri', None)
473
+ self.register_parameter('Wii', None)
474
+ self.register_parameter('Br', None)
475
+ self.register_parameter('Bi', None)
476
+ if self.track_running_stats:
477
+ self.register_buffer('RMr', torch.zeros(num_features))
478
+ self.register_buffer('RMi', torch.zeros(num_features))
479
+ self.register_buffer('RVrr', torch.ones (num_features))
480
+ self.register_buffer('RVri', torch.zeros(num_features))
481
+ self.register_buffer('RVii', torch.ones (num_features))
482
+ self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
483
+ else:
484
+ self.register_parameter('RMr', None)
485
+ self.register_parameter('RMi', None)
486
+ self.register_parameter('RVrr', None)
487
+ self.register_parameter('RVri', None)
488
+ self.register_parameter('RVii', None)
489
+ self.register_parameter('num_batches_tracked', None)
490
+ self.reset_parameters()
491
+
492
+ def reset_running_stats(self):
493
+ if self.track_running_stats:
494
+ self.RMr.zero_()
495
+ self.RMi.zero_()
496
+ self.RVrr.fill_(1)
497
+ self.RVri.zero_()
498
+ self.RVii.fill_(1)
499
+ self.num_batches_tracked.zero_()
500
+
501
+ def reset_parameters(self):
502
+ self.reset_running_stats()
503
+ if self.affine:
504
+ self.Br.data.zero_()
505
+ self.Bi.data.zero_()
506
+ self.Wrr.data.fill_(1)
507
+ self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite
508
+ self.Wii.data.fill_(1)
509
+
510
+ def _check_input_dim(self, xr, xi):
511
+ assert(xr.shape == xi.shape)
512
+ assert(xr.size(1) == self.num_features)
513
+
514
+ def forward(self, x):
515
+ xr, xi = x.real, x.imag
516
+ self._check_input_dim(xr, xi)
517
+
518
+ exponential_average_factor = 0.0
519
+
520
+ if self.training and self.track_running_stats:
521
+ self.num_batches_tracked += 1
522
+ if self.momentum is None: # use cumulative moving average
523
+ exponential_average_factor = 1.0 / self.num_batches_tracked.item()
524
+ else: # use exponential moving average
525
+ exponential_average_factor = self.momentum
526
+
527
+ #
528
+ # NOTE: The precise meaning of the "training flag" is:
529
+ # True: Normalize using batch statistics, update running statistics
530
+ # if they are being collected.
531
+ # False: Normalize using running statistics, ignore batch statistics.
532
+ #
533
+ training = self.training or not self.track_running_stats
534
+ redux = [i for i in reversed(range(xr.dim())) if i!=1]
535
+ vdim = [1] * xr.dim()
536
+ vdim[1] = xr.size(1)
537
+
538
+ #
539
+ # Mean M Computation and Centering
540
+ #
541
+ # Includes running mean update if training and running.
542
+ #
543
+ if training:
544
+ Mr, Mi = xr, xi
545
+ for d in redux:
546
+ Mr = Mr.mean(d, keepdim=True)
547
+ Mi = Mi.mean(d, keepdim=True)
548
+ if self.track_running_stats:
549
+ self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
550
+ self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
551
+ else:
552
+ Mr = self.RMr.view(vdim)
553
+ Mi = self.RMi.view(vdim)
554
+ xr, xi = xr-Mr, xi-Mi
555
+
556
+ #
557
+ # Variance Matrix V Computation
558
+ #
559
+ # Includes epsilon numerical stabilizer/Tikhonov regularizer.
560
+ # Includes running variance update if training and running.
561
+ #
562
+ if training:
563
+ Vrr = xr * xr
564
+ Vri = xr * xi
565
+ Vii = xi * xi
566
+ for d in redux:
567
+ Vrr = Vrr.mean(d, keepdim=True)
568
+ Vri = Vri.mean(d, keepdim=True)
569
+ Vii = Vii.mean(d, keepdim=True)
570
+ if self.track_running_stats:
571
+ self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
572
+ self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
573
+ self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
574
+ else:
575
+ Vrr = self.RVrr.view(vdim)
576
+ Vri = self.RVri.view(vdim)
577
+ Vii = self.RVii.view(vdim)
578
+ Vrr = Vrr + self.eps
579
+ Vri = Vri
580
+ Vii = Vii + self.eps
581
+
582
+ #
583
+ # Matrix Inverse Square Root U = V^-0.5
584
+ #
585
+ # sqrt of a 2x2 matrix,
586
+ # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
587
+ tau = Vrr + Vii
588
+ delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
589
+ s = delta.sqrt()
590
+ t = (tau + 2*s).sqrt()
591
+
592
+ # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
593
+ rst = (s * t).reciprocal()
594
+ Urr = (s + Vii) * rst
595
+ Uii = (s + Vrr) * rst
596
+ Uri = ( - Vri) * rst
597
+
598
+ #
599
+ # Optionally left-multiply U by affine weights W to produce combined
600
+ # weights Z, left-multiply the inputs by Z, then optionally bias them.
601
+ #
602
+ # y = Zx + B
603
+ # y = WUx + B
604
+ # y = [Wrr Wri][Urr Uri] [xr] + [Br]
605
+ # [Wir Wii][Uir Uii] [xi] [Bi]
606
+ #
607
+ if self.affine:
608
+ Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
609
+ Zrr = (Wrr * Urr) + (Wri * Uri)
610
+ Zri = (Wrr * Uri) + (Wri * Uii)
611
+ Zir = (Wri * Urr) + (Wii * Uri)
612
+ Zii = (Wri * Uri) + (Wii * Uii)
613
+ else:
614
+ Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
615
+
616
+ yr = (Zrr * xr) + (Zri * xi)
617
+ yi = (Zir * xr) + (Zii * xi)
618
+
619
+ if self.affine:
620
+ yr = yr + self.Br.view(vdim)
621
+ yi = yi + self.Bi.view(vdim)
622
+
623
+ return torch.view_as_complex(torch.stack([yr, yi], dim=-1))
624
+
625
+ def extra_repr(self):
626
+ return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
627
+ 'track_running_stats={track_running_stats}'.format(**self.__dict__)
sgmse/backbones/ncsnpp.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: skip-file
17
+
18
+ from .ncsnpp_utils import layers, layerspp, normalization
19
+ import torch.nn as nn
20
+ import functools
21
+ import torch
22
+ import numpy as np
23
+
24
+ from .shared import BackboneRegistry
25
+
26
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
27
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
28
+ Combine = layerspp.Combine
29
+ conv3x3 = layerspp.conv3x3
30
+ conv1x1 = layerspp.conv1x1
31
+ get_act = layers.get_act
32
+ get_normalization = normalization.get_normalization
33
+ default_initializer = layers.default_init
34
+
35
+
36
+ @BackboneRegistry.register("ncsnpp")
37
+ class NCSNpp(nn.Module):
38
+ """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
39
+
40
+ @staticmethod
41
+ def add_argparse_args(parser):
42
+ parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
43
+ parser.add_argument("--num_res_blocks", type=int, default=2)
44
+ parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[16])
45
+ parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
46
+ parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
47
+ parser.set_defaults(centered=True)
48
+ return parser
49
+
50
+ def __init__(self,
51
+ scale_by_sigma = True,
52
+ nonlinearity = 'swish',
53
+ nf = 128,
54
+ ch_mult = (1, 1, 2, 2, 2, 2, 2),
55
+ num_res_blocks = 2,
56
+ attn_resolutions = (16,),
57
+ resamp_with_conv = True,
58
+ conditional = True,
59
+ fir = True,
60
+ fir_kernel = [1, 3, 3, 1],
61
+ skip_rescale = True,
62
+ resblock_type = 'biggan',
63
+ progressive = 'output_skip',
64
+ progressive_input = 'input_skip',
65
+ progressive_combine = 'sum',
66
+ init_scale = 0.,
67
+ fourier_scale = 16,
68
+ image_size = 256,
69
+ embedding_type = 'fourier',
70
+ dropout = .0,
71
+ centered = True,
72
+ **unused_kwargs
73
+ ):
74
+ super().__init__()
75
+ self.act = act = get_act(nonlinearity)
76
+
77
+ self.nf = nf = nf
78
+ ch_mult = ch_mult
79
+ self.num_res_blocks = num_res_blocks = num_res_blocks
80
+ self.attn_resolutions = attn_resolutions = attn_resolutions
81
+ dropout = dropout
82
+ resamp_with_conv = resamp_with_conv
83
+ self.num_resolutions = num_resolutions = len(ch_mult)
84
+ self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
85
+
86
+ self.conditional = conditional = conditional # noise-conditional
87
+ self.centered = centered
88
+ self.scale_by_sigma = scale_by_sigma
89
+
90
+ fir = fir
91
+ fir_kernel = fir_kernel
92
+ self.skip_rescale = skip_rescale = skip_rescale
93
+ self.resblock_type = resblock_type = resblock_type.lower()
94
+ self.progressive = progressive = progressive.lower()
95
+ self.progressive_input = progressive_input = progressive_input.lower()
96
+ self.embedding_type = embedding_type = embedding_type.lower()
97
+ init_scale = init_scale
98
+ assert progressive in ['none', 'output_skip', 'residual']
99
+ assert progressive_input in ['none', 'input_skip', 'residual']
100
+ assert embedding_type in ['fourier', 'positional']
101
+ combine_method = progressive_combine.lower()
102
+ combiner = functools.partial(Combine, method=combine_method)
103
+
104
+ num_channels = 4 # x.real, x.imag, y.real, y.imag
105
+ self.output_layer = nn.Conv2d(num_channels, 2, 1)
106
+
107
+ modules = []
108
+ # timestep/noise_level embedding
109
+ if embedding_type == 'fourier':
110
+ # Gaussian Fourier features embeddings.
111
+ modules.append(layerspp.GaussianFourierProjection(
112
+ embedding_size=nf, scale=fourier_scale
113
+ ))
114
+ embed_dim = 2 * nf
115
+ elif embedding_type == 'positional':
116
+ embed_dim = nf
117
+ else:
118
+ raise ValueError(f'embedding type {embedding_type} unknown.')
119
+
120
+ if conditional:
121
+ modules.append(nn.Linear(embed_dim, nf * 4))
122
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
123
+ nn.init.zeros_(modules[-1].bias)
124
+ modules.append(nn.Linear(nf * 4, nf * 4))
125
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
126
+ nn.init.zeros_(modules[-1].bias)
127
+
128
+ AttnBlock = functools.partial(layerspp.AttnBlockpp,
129
+ init_scale=init_scale, skip_rescale=skip_rescale)
130
+
131
+ Upsample = functools.partial(layerspp.Upsample,
132
+ with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
133
+
134
+ if progressive == 'output_skip':
135
+ self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
136
+ elif progressive == 'residual':
137
+ pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
138
+ fir_kernel=fir_kernel, with_conv=True)
139
+
140
+ Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
141
+
142
+ if progressive_input == 'input_skip':
143
+ self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
144
+ elif progressive_input == 'residual':
145
+ pyramid_downsample = functools.partial(layerspp.Downsample,
146
+ fir=fir, fir_kernel=fir_kernel, with_conv=True)
147
+
148
+ if resblock_type == 'ddpm':
149
+ ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
150
+ dropout=dropout, init_scale=init_scale,
151
+ skip_rescale=skip_rescale, temb_dim=nf * 4)
152
+
153
+ elif resblock_type == 'biggan':
154
+ ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
155
+ dropout=dropout, fir=fir, fir_kernel=fir_kernel,
156
+ init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
157
+
158
+ else:
159
+ raise ValueError(f'resblock type {resblock_type} unrecognized.')
160
+
161
+ # Downsampling block
162
+
163
+ channels = num_channels
164
+ if progressive_input != 'none':
165
+ input_pyramid_ch = channels
166
+
167
+ modules.append(conv3x3(channels, nf))
168
+ hs_c = [nf]
169
+
170
+ in_ch = nf
171
+ for i_level in range(num_resolutions):
172
+ # Residual blocks for this resolution
173
+ for i_block in range(num_res_blocks):
174
+ out_ch = nf * ch_mult[i_level]
175
+ modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
176
+ in_ch = out_ch
177
+
178
+ if all_resolutions[i_level] in attn_resolutions:
179
+ modules.append(AttnBlock(channels=in_ch))
180
+ hs_c.append(in_ch)
181
+
182
+ if i_level != num_resolutions - 1:
183
+ if resblock_type == 'ddpm':
184
+ modules.append(Downsample(in_ch=in_ch))
185
+ else:
186
+ modules.append(ResnetBlock(down=True, in_ch=in_ch))
187
+
188
+ if progressive_input == 'input_skip':
189
+ modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
190
+ if combine_method == 'cat':
191
+ in_ch *= 2
192
+
193
+ elif progressive_input == 'residual':
194
+ modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
195
+ input_pyramid_ch = in_ch
196
+
197
+ hs_c.append(in_ch)
198
+
199
+ in_ch = hs_c[-1]
200
+ modules.append(ResnetBlock(in_ch=in_ch))
201
+ modules.append(AttnBlock(channels=in_ch))
202
+ modules.append(ResnetBlock(in_ch=in_ch))
203
+
204
+ pyramid_ch = 0
205
+ # Upsampling block
206
+ for i_level in reversed(range(num_resolutions)):
207
+ for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
208
+ out_ch = nf * ch_mult[i_level]
209
+ modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
210
+ in_ch = out_ch
211
+
212
+ if all_resolutions[i_level] in attn_resolutions:
213
+ modules.append(AttnBlock(channels=in_ch))
214
+
215
+ if progressive != 'none':
216
+ if i_level == num_resolutions - 1:
217
+ if progressive == 'output_skip':
218
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
219
+ num_channels=in_ch, eps=1e-6))
220
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
221
+ pyramid_ch = channels
222
+ elif progressive == 'residual':
223
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
224
+ modules.append(conv3x3(in_ch, in_ch, bias=True))
225
+ pyramid_ch = in_ch
226
+ else:
227
+ raise ValueError(f'{progressive} is not a valid name.')
228
+ else:
229
+ if progressive == 'output_skip':
230
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
231
+ num_channels=in_ch, eps=1e-6))
232
+ modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
233
+ pyramid_ch = channels
234
+ elif progressive == 'residual':
235
+ modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
236
+ pyramid_ch = in_ch
237
+ else:
238
+ raise ValueError(f'{progressive} is not a valid name')
239
+
240
+ if i_level != 0:
241
+ if resblock_type == 'ddpm':
242
+ modules.append(Upsample(in_ch=in_ch))
243
+ else:
244
+ modules.append(ResnetBlock(in_ch=in_ch, up=True))
245
+
246
+ assert not hs_c
247
+
248
+ if progressive != 'output_skip':
249
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
250
+ num_channels=in_ch, eps=1e-6))
251
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
252
+
253
+ self.all_modules = nn.ModuleList(modules)
254
+
255
+
256
+ def forward(self, x, time_cond):
257
+ # timestep/noise_level embedding; only for continuous training
258
+ modules = self.all_modules
259
+ m_idx = 0
260
+
261
+ # Convert real and imaginary parts of (x,y) into four channel dimensions
262
+ x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
263
+ x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
264
+
265
+ if self.embedding_type == 'fourier':
266
+ # Gaussian Fourier features embeddings.
267
+ used_sigmas = time_cond
268
+ temb = modules[m_idx](torch.log(used_sigmas))
269
+ m_idx += 1
270
+
271
+ elif self.embedding_type == 'positional':
272
+ # Sinusoidal positional embeddings.
273
+ timesteps = time_cond
274
+ used_sigmas = self.sigmas[time_cond.long()]
275
+ temb = layers.get_timestep_embedding(timesteps, self.nf)
276
+
277
+ else:
278
+ raise ValueError(f'embedding type {self.embedding_type} unknown.')
279
+
280
+ if self.conditional:
281
+ temb = modules[m_idx](temb)
282
+ m_idx += 1
283
+ temb = modules[m_idx](self.act(temb))
284
+ m_idx += 1
285
+ else:
286
+ temb = None
287
+
288
+ if not self.centered:
289
+ # If input data is in [0, 1]
290
+ x = 2 * x - 1.
291
+
292
+ # Downsampling block
293
+ input_pyramid = None
294
+ if self.progressive_input != 'none':
295
+ input_pyramid = x
296
+
297
+ # Input layer: Conv2d: 4ch -> 128ch
298
+ hs = [modules[m_idx](x)]
299
+ m_idx += 1
300
+
301
+ # Down path in U-Net
302
+ for i_level in range(self.num_resolutions):
303
+ # Residual blocks for this resolution
304
+ for i_block in range(self.num_res_blocks):
305
+ h = modules[m_idx](hs[-1], temb)
306
+ m_idx += 1
307
+ # Attention layer (optional)
308
+ if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
309
+ h = modules[m_idx](h)
310
+ m_idx += 1
311
+ hs.append(h)
312
+
313
+ # Downsampling
314
+ if i_level != self.num_resolutions - 1:
315
+ if self.resblock_type == 'ddpm':
316
+ h = modules[m_idx](hs[-1])
317
+ m_idx += 1
318
+ else:
319
+ h = modules[m_idx](hs[-1], temb)
320
+ m_idx += 1
321
+
322
+ if self.progressive_input == 'input_skip': # Combine h with x
323
+ input_pyramid = self.pyramid_downsample(input_pyramid)
324
+ h = modules[m_idx](input_pyramid, h)
325
+ m_idx += 1
326
+
327
+ elif self.progressive_input == 'residual':
328
+ input_pyramid = modules[m_idx](input_pyramid)
329
+ m_idx += 1
330
+ if self.skip_rescale:
331
+ input_pyramid = (input_pyramid + h) / np.sqrt(2.)
332
+ else:
333
+ input_pyramid = input_pyramid + h
334
+ h = input_pyramid
335
+ hs.append(h)
336
+
337
+ h = hs[-1] # actualy equal to: h = h
338
+ h = modules[m_idx](h, temb) # ResNet block
339
+ m_idx += 1
340
+ h = modules[m_idx](h) # Attention block
341
+ m_idx += 1
342
+ h = modules[m_idx](h, temb) # ResNet block
343
+ m_idx += 1
344
+
345
+ pyramid = None
346
+
347
+ # Upsampling block
348
+ for i_level in reversed(range(self.num_resolutions)):
349
+ for i_block in range(self.num_res_blocks + 1):
350
+ h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
351
+ m_idx += 1
352
+
353
+ # edit: from -1 to -2
354
+ if h.shape[-2] in self.attn_resolutions:
355
+ h = modules[m_idx](h)
356
+ m_idx += 1
357
+
358
+ if self.progressive != 'none':
359
+ if i_level == self.num_resolutions - 1:
360
+ if self.progressive == 'output_skip':
361
+ pyramid = self.act(modules[m_idx](h)) # GroupNorm
362
+ m_idx += 1
363
+ pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
364
+ m_idx += 1
365
+ elif self.progressive == 'residual':
366
+ pyramid = self.act(modules[m_idx](h))
367
+ m_idx += 1
368
+ pyramid = modules[m_idx](pyramid)
369
+ m_idx += 1
370
+ else:
371
+ raise ValueError(f'{self.progressive} is not a valid name.')
372
+ else:
373
+ if self.progressive == 'output_skip':
374
+ pyramid = self.pyramid_upsample(pyramid) # Upsample
375
+ pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
376
+ m_idx += 1
377
+ pyramid_h = modules[m_idx](pyramid_h)
378
+ m_idx += 1
379
+ pyramid = pyramid + pyramid_h
380
+ elif self.progressive == 'residual':
381
+ pyramid = modules[m_idx](pyramid)
382
+ m_idx += 1
383
+ if self.skip_rescale:
384
+ pyramid = (pyramid + h) / np.sqrt(2.)
385
+ else:
386
+ pyramid = pyramid + h
387
+ h = pyramid
388
+ else:
389
+ raise ValueError(f'{self.progressive} is not a valid name')
390
+
391
+ # Upsampling Layer
392
+ if i_level != 0:
393
+ if self.resblock_type == 'ddpm':
394
+ h = modules[m_idx](h)
395
+ m_idx += 1
396
+ else:
397
+ h = modules[m_idx](h, temb) # Upspampling
398
+ m_idx += 1
399
+
400
+ assert not hs
401
+
402
+ if self.progressive == 'output_skip':
403
+ h = pyramid
404
+ else:
405
+ h = self.act(modules[m_idx](h))
406
+ m_idx += 1
407
+ h = modules[m_idx](h)
408
+ m_idx += 1
409
+
410
+ assert m_idx == len(modules), "Implementation error"
411
+ if self.scale_by_sigma:
412
+ used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
413
+ h = h / used_sigmas
414
+
415
+ # Convert back to complex number
416
+ h = self.output_layer(h)
417
+ h = torch.permute(h, (0, 2, 3, 1)).contiguous()
418
+ h = torch.view_as_complex(h)[:,None, :, :]
419
+ return h
sgmse/backbones/ncsnpp_48k.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: skip-file
17
+
18
+ from .ncsnpp_utils import layers, layerspp, normalization
19
+ import torch.nn as nn
20
+ import functools
21
+ import torch
22
+ import numpy as np
23
+
24
+ from .shared import BackboneRegistry
25
+
26
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
27
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
28
+ Combine = layerspp.Combine
29
+ conv3x3 = layerspp.conv3x3
30
+ conv1x1 = layerspp.conv1x1
31
+ get_act = layers.get_act
32
+ get_normalization = normalization.get_normalization
33
+ default_initializer = layers.default_init
34
+
35
+
36
+ @BackboneRegistry.register("ncsnpp_48k")
37
+ class NCSNpp_48k(nn.Module):
38
+ """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
39
+
40
+ @staticmethod
41
+ def add_argparse_args(parser):
42
+ parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
43
+ parser.add_argument("--num_res_blocks", type=int, default=2)
44
+ parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
45
+ parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
46
+ parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
47
+ parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
48
+ parser.add_argument("--progressive", type=str, default='none', help="Progressive downsampling method")
49
+ parser.add_argument("--progressive_input", type=str, default='none', help="Progressive upsampling method")
50
+ parser.set_defaults(centered=True)
51
+ return parser
52
+
53
+ def __init__(self,
54
+ scale_by_sigma = True,
55
+ nonlinearity = 'swish',
56
+ nf = 128,
57
+ ch_mult = (1, 1, 2, 2, 2, 2, 2),
58
+ num_res_blocks = 2,
59
+ attn_resolutions = (),
60
+ resamp_with_conv = True,
61
+ conditional = True,
62
+ fir = True,
63
+ fir_kernel = [1, 3, 3, 1],
64
+ skip_rescale = True,
65
+ resblock_type = 'biggan',
66
+ progressive = 'none',
67
+ progressive_input = 'none',
68
+ progressive_combine = 'sum',
69
+ init_scale = 0.,
70
+ fourier_scale = 16,
71
+ image_size = 256,
72
+ embedding_type = 'fourier',
73
+ dropout = .0,
74
+ centered = True,
75
+ **unused_kwargs
76
+ ):
77
+ super().__init__()
78
+ self.act = act = get_act(nonlinearity)
79
+
80
+ self.nf = nf = nf
81
+ ch_mult = ch_mult
82
+ self.num_res_blocks = num_res_blocks = num_res_blocks
83
+ self.attn_resolutions = attn_resolutions
84
+ dropout = dropout
85
+ resamp_with_conv = resamp_with_conv
86
+ self.num_resolutions = num_resolutions = len(ch_mult)
87
+ self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
88
+
89
+ self.conditional = conditional = conditional # noise-conditional
90
+ self.centered = centered
91
+ self.scale_by_sigma = scale_by_sigma
92
+
93
+ fir = fir
94
+ fir_kernel = fir_kernel
95
+ self.skip_rescale = skip_rescale = skip_rescale
96
+ self.resblock_type = resblock_type = resblock_type.lower()
97
+ self.progressive = progressive = progressive.lower()
98
+ self.progressive_input = progressive_input = progressive_input.lower()
99
+ self.embedding_type = embedding_type = embedding_type.lower()
100
+ init_scale = init_scale
101
+ assert progressive in ['none', 'output_skip', 'residual']
102
+ assert progressive_input in ['none', 'input_skip', 'residual']
103
+ assert embedding_type in ['fourier', 'positional']
104
+ combine_method = progressive_combine.lower()
105
+ combiner = functools.partial(Combine, method=combine_method)
106
+
107
+ num_channels = 4 # x.real, x.imag, y.real, y.imag
108
+ self.output_layer = nn.Conv2d(num_channels, 2, 1)
109
+
110
+ modules = []
111
+ # timestep/noise_level embedding
112
+ if embedding_type == 'fourier':
113
+ # Gaussian Fourier features embeddings.
114
+ modules.append(layerspp.GaussianFourierProjection(
115
+ embedding_size=nf, scale=fourier_scale
116
+ ))
117
+ embed_dim = 2 * nf
118
+ elif embedding_type == 'positional':
119
+ embed_dim = nf
120
+ else:
121
+ raise ValueError(f'embedding type {embedding_type} unknown.')
122
+
123
+ if conditional:
124
+ modules.append(nn.Linear(embed_dim, nf * 4))
125
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
126
+ nn.init.zeros_(modules[-1].bias)
127
+ modules.append(nn.Linear(nf * 4, nf * 4))
128
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
129
+ nn.init.zeros_(modules[-1].bias)
130
+
131
+ AttnBlock = functools.partial(layerspp.AttnBlockpp,
132
+ init_scale=init_scale, skip_rescale=skip_rescale)
133
+
134
+ Upsample = functools.partial(layerspp.Upsample,
135
+ with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
136
+
137
+ if progressive == 'output_skip':
138
+ self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
139
+ elif progressive == 'residual':
140
+ pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
141
+ fir_kernel=fir_kernel, with_conv=True)
142
+
143
+ Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
144
+
145
+ if progressive_input == 'input_skip':
146
+ self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
147
+ elif progressive_input == 'residual':
148
+ pyramid_downsample = functools.partial(layerspp.Downsample,
149
+ fir=fir, fir_kernel=fir_kernel, with_conv=True)
150
+
151
+ if resblock_type == 'ddpm':
152
+ ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
153
+ dropout=dropout, init_scale=init_scale,
154
+ skip_rescale=skip_rescale, temb_dim=nf * 4)
155
+
156
+ elif resblock_type == 'biggan':
157
+ ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
158
+ dropout=dropout, fir=fir, fir_kernel=fir_kernel,
159
+ init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
160
+
161
+ else:
162
+ raise ValueError(f'resblock type {resblock_type} unrecognized.')
163
+
164
+ # Downsampling block
165
+
166
+ channels = num_channels
167
+ if progressive_input != 'none':
168
+ input_pyramid_ch = channels
169
+
170
+ modules.append(conv3x3(channels, nf))
171
+ hs_c = [nf]
172
+
173
+ in_ch = nf
174
+ for i_level in range(num_resolutions):
175
+ # Residual blocks for this resolution
176
+ for i_block in range(num_res_blocks):
177
+ out_ch = nf * ch_mult[i_level]
178
+ modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
179
+ in_ch = out_ch
180
+
181
+ if all_resolutions[i_level] in attn_resolutions:
182
+ modules.append(AttnBlock(channels=in_ch))
183
+ hs_c.append(in_ch)
184
+
185
+ if i_level != num_resolutions - 1:
186
+ if resblock_type == 'ddpm':
187
+ modules.append(Downsample(in_ch=in_ch))
188
+ else:
189
+ modules.append(ResnetBlock(down=True, in_ch=in_ch))
190
+
191
+ if progressive_input == 'input_skip':
192
+ modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
193
+ if combine_method == 'cat':
194
+ in_ch *= 2
195
+
196
+ elif progressive_input == 'residual':
197
+ modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
198
+ input_pyramid_ch = in_ch
199
+
200
+ hs_c.append(in_ch)
201
+
202
+ in_ch = hs_c[-1]
203
+ modules.append(ResnetBlock(in_ch=in_ch))
204
+ modules.append(AttnBlock(channels=in_ch))
205
+ modules.append(ResnetBlock(in_ch=in_ch))
206
+
207
+ pyramid_ch = 0
208
+ # Upsampling block
209
+ for i_level in reversed(range(num_resolutions)):
210
+ for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
211
+ out_ch = nf * ch_mult[i_level]
212
+ modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
213
+ in_ch = out_ch
214
+
215
+ if all_resolutions[i_level] in attn_resolutions:
216
+ modules.append(AttnBlock(channels=in_ch))
217
+
218
+ if progressive != 'none':
219
+ if i_level == num_resolutions - 1:
220
+ if progressive == 'output_skip':
221
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
222
+ num_channels=in_ch, eps=1e-6))
223
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
224
+ pyramid_ch = channels
225
+ elif progressive == 'residual':
226
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
227
+ modules.append(conv3x3(in_ch, in_ch, bias=True))
228
+ pyramid_ch = in_ch
229
+ else:
230
+ raise ValueError(f'{progressive} is not a valid name.')
231
+ else:
232
+ if progressive == 'output_skip':
233
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
234
+ num_channels=in_ch, eps=1e-6))
235
+ modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
236
+ pyramid_ch = channels
237
+ elif progressive == 'residual':
238
+ modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
239
+ pyramid_ch = in_ch
240
+ else:
241
+ raise ValueError(f'{progressive} is not a valid name')
242
+
243
+ if i_level != 0:
244
+ if resblock_type == 'ddpm':
245
+ modules.append(Upsample(in_ch=in_ch))
246
+ else:
247
+ modules.append(ResnetBlock(in_ch=in_ch, up=True))
248
+
249
+ assert not hs_c
250
+
251
+ if progressive != 'output_skip':
252
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
253
+ num_channels=in_ch, eps=1e-6))
254
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
255
+
256
+ self.all_modules = nn.ModuleList(modules)
257
+
258
+
259
+ def forward(self, x, time_cond):
260
+ # timestep/noise_level embedding; only for continuous training
261
+ modules = self.all_modules
262
+ m_idx = 0
263
+
264
+ # Convert real and imaginary parts of (x,y) into four channel dimensions
265
+ x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
266
+ x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
267
+
268
+ if self.embedding_type == 'fourier':
269
+ # Gaussian Fourier features embeddings.
270
+ used_sigmas = time_cond
271
+ temb = modules[m_idx](torch.log(used_sigmas))
272
+ m_idx += 1
273
+
274
+ elif self.embedding_type == 'positional':
275
+ # Sinusoidal positional embeddings.
276
+ timesteps = time_cond
277
+ used_sigmas = self.sigmas[time_cond.long()]
278
+ temb = layers.get_timestep_embedding(timesteps, self.nf)
279
+
280
+ else:
281
+ raise ValueError(f'embedding type {self.embedding_type} unknown.')
282
+
283
+ if self.conditional:
284
+ temb = modules[m_idx](temb)
285
+ m_idx += 1
286
+ temb = modules[m_idx](self.act(temb))
287
+ m_idx += 1
288
+ else:
289
+ temb = None
290
+
291
+ if not self.centered:
292
+ # If input data is in [0, 1]
293
+ x = 2 * x - 1.
294
+
295
+ # Downsampling block
296
+ input_pyramid = None
297
+ if self.progressive_input != 'none':
298
+ input_pyramid = x
299
+
300
+ # Input layer: Conv2d: 4ch -> 128ch
301
+ hs = [modules[m_idx](x)]
302
+ m_idx += 1
303
+
304
+ # Down path in U-Net
305
+ for i_level in range(self.num_resolutions):
306
+ # Residual blocks for this resolution
307
+ for i_block in range(self.num_res_blocks):
308
+ h = modules[m_idx](hs[-1], temb)
309
+ m_idx += 1
310
+ # Attention layer (optional)
311
+ if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
312
+ h = modules[m_idx](h)
313
+ m_idx += 1
314
+ hs.append(h)
315
+
316
+ # Downsampling
317
+ if i_level != self.num_resolutions - 1:
318
+ if self.resblock_type == 'ddpm':
319
+ h = modules[m_idx](hs[-1])
320
+ m_idx += 1
321
+ else:
322
+ h = modules[m_idx](hs[-1], temb)
323
+ m_idx += 1
324
+
325
+ if self.progressive_input == 'input_skip': # Combine h with x
326
+ input_pyramid = self.pyramid_downsample(input_pyramid)
327
+ h = modules[m_idx](input_pyramid, h)
328
+ m_idx += 1
329
+
330
+ elif self.progressive_input == 'residual':
331
+ input_pyramid = modules[m_idx](input_pyramid)
332
+ m_idx += 1
333
+ if self.skip_rescale:
334
+ input_pyramid = (input_pyramid + h) / np.sqrt(2.)
335
+ else:
336
+ input_pyramid = input_pyramid + h
337
+ h = input_pyramid
338
+ hs.append(h)
339
+
340
+ h = hs[-1] # actualy equal to: h = h
341
+ h = modules[m_idx](h, temb) # ResNet block
342
+ m_idx += 1
343
+ h = modules[m_idx](h) # Attention block
344
+ m_idx += 1
345
+ h = modules[m_idx](h, temb) # ResNet block
346
+ m_idx += 1
347
+
348
+ pyramid = None
349
+
350
+ # Upsampling block
351
+ for i_level in reversed(range(self.num_resolutions)):
352
+ for i_block in range(self.num_res_blocks + 1):
353
+ h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
354
+ m_idx += 1
355
+
356
+ # edit: from -1 to -2
357
+ if h.shape[-2] in self.attn_resolutions:
358
+ h = modules[m_idx](h)
359
+ m_idx += 1
360
+
361
+ if self.progressive != 'none':
362
+ if i_level == self.num_resolutions - 1:
363
+ if self.progressive == 'output_skip':
364
+ pyramid = self.act(modules[m_idx](h)) # GroupNorm
365
+ m_idx += 1
366
+ pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
367
+ m_idx += 1
368
+ elif self.progressive == 'residual':
369
+ pyramid = self.act(modules[m_idx](h))
370
+ m_idx += 1
371
+ pyramid = modules[m_idx](pyramid)
372
+ m_idx += 1
373
+ else:
374
+ raise ValueError(f'{self.progressive} is not a valid name.')
375
+ else:
376
+ if self.progressive == 'output_skip':
377
+ pyramid = self.pyramid_upsample(pyramid) # Upsample
378
+ pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
379
+ m_idx += 1
380
+ pyramid_h = modules[m_idx](pyramid_h)
381
+ m_idx += 1
382
+ pyramid = pyramid + pyramid_h
383
+ elif self.progressive == 'residual':
384
+ pyramid = modules[m_idx](pyramid)
385
+ m_idx += 1
386
+ if self.skip_rescale:
387
+ pyramid = (pyramid + h) / np.sqrt(2.)
388
+ else:
389
+ pyramid = pyramid + h
390
+ h = pyramid
391
+ else:
392
+ raise ValueError(f'{self.progressive} is not a valid name')
393
+
394
+ # Upsampling Layer
395
+ if i_level != 0:
396
+ if self.resblock_type == 'ddpm':
397
+ h = modules[m_idx](h)
398
+ m_idx += 1
399
+ else:
400
+ h = modules[m_idx](h, temb) # Upspampling
401
+ m_idx += 1
402
+
403
+ assert not hs
404
+
405
+ if self.progressive == 'output_skip':
406
+ h = pyramid
407
+ else:
408
+ h = self.act(modules[m_idx](h))
409
+ m_idx += 1
410
+ h = modules[m_idx](h)
411
+ m_idx += 1
412
+
413
+ assert m_idx == len(modules), "Implementation error"
414
+
415
+ # Convert back to complex number
416
+ h = self.output_layer(h)
417
+
418
+ if self.scale_by_sigma:
419
+ used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
420
+ h = h / used_sigmas
421
+
422
+ h = torch.permute(h, (0, 2, 3, 1)).contiguous()
423
+ h = torch.view_as_complex(h)[:,None, :, :]
424
+ return h
sgmse/backbones/ncsnpp_utils/layers.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: skip-file
17
+ """Common layers for defining score networks.
18
+ """
19
+ import math
20
+ import string
21
+ from functools import partial
22
+ import torch.nn as nn
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import numpy as np
26
+ from .normalization import ConditionalInstanceNorm2dPlus
27
+
28
+
29
+ def get_act(config):
30
+ """Get activation functions from the config file."""
31
+
32
+ if config == 'elu':
33
+ return nn.ELU()
34
+ elif config == 'relu':
35
+ return nn.ReLU()
36
+ elif config == 'lrelu':
37
+ return nn.LeakyReLU(negative_slope=0.2)
38
+ elif config == 'swish':
39
+ return nn.SiLU()
40
+ else:
41
+ raise NotImplementedError('activation function does not exist!')
42
+
43
+
44
+ def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
45
+ """1x1 convolution. Same as NCSNv1/v2."""
46
+ conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
47
+ padding=padding)
48
+ init_scale = 1e-10 if init_scale == 0 else init_scale
49
+ conv.weight.data *= init_scale
50
+ conv.bias.data *= init_scale
51
+ return conv
52
+
53
+
54
+ def variance_scaling(scale, mode, distribution,
55
+ in_axis=1, out_axis=0,
56
+ dtype=torch.float32,
57
+ device='cpu'):
58
+ """Ported from JAX. """
59
+
60
+ def _compute_fans(shape, in_axis=1, out_axis=0):
61
+ receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
62
+ fan_in = shape[in_axis] * receptive_field_size
63
+ fan_out = shape[out_axis] * receptive_field_size
64
+ return fan_in, fan_out
65
+
66
+ def init(shape, dtype=dtype, device=device):
67
+ fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
68
+ if mode == "fan_in":
69
+ denominator = fan_in
70
+ elif mode == "fan_out":
71
+ denominator = fan_out
72
+ elif mode == "fan_avg":
73
+ denominator = (fan_in + fan_out) / 2
74
+ else:
75
+ raise ValueError(
76
+ "invalid mode for variance scaling initializer: {}".format(mode))
77
+ variance = scale / denominator
78
+ if distribution == "normal":
79
+ return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
80
+ elif distribution == "uniform":
81
+ return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
82
+ else:
83
+ raise ValueError("invalid distribution for variance scaling initializer")
84
+
85
+ return init
86
+
87
+
88
+ def default_init(scale=1.):
89
+ """The same initialization used in DDPM."""
90
+ scale = 1e-10 if scale == 0 else scale
91
+ return variance_scaling(scale, 'fan_avg', 'uniform')
92
+
93
+
94
+ class Dense(nn.Module):
95
+ """Linear layer with `default_init`."""
96
+ def __init__(self):
97
+ super().__init__()
98
+
99
+
100
+ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
101
+ """1x1 convolution with DDPM initialization."""
102
+ conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
103
+ conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
104
+ nn.init.zeros_(conv.bias)
105
+ return conv
106
+
107
+
108
+ def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
109
+ """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
110
+ init_scale = 1e-10 if init_scale == 0 else init_scale
111
+ conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
112
+ dilation=dilation, padding=padding, kernel_size=3)
113
+ conv.weight.data *= init_scale
114
+ conv.bias.data *= init_scale
115
+ return conv
116
+
117
+
118
+ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
119
+ """3x3 convolution with DDPM initialization."""
120
+ conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
121
+ dilation=dilation, bias=bias)
122
+ conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
123
+ nn.init.zeros_(conv.bias)
124
+ return conv
125
+
126
+ ###########################################################################
127
+ # Functions below are ported over from the NCSNv1/NCSNv2 codebase:
128
+ # https://github.com/ermongroup/ncsn
129
+ # https://github.com/ermongroup/ncsnv2
130
+ ###########################################################################
131
+
132
+
133
+ class CRPBlock(nn.Module):
134
+ def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
135
+ super().__init__()
136
+ self.convs = nn.ModuleList()
137
+ for i in range(n_stages):
138
+ self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
139
+ self.n_stages = n_stages
140
+ if maxpool:
141
+ self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
142
+ else:
143
+ self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
144
+
145
+ self.act = act
146
+
147
+ def forward(self, x):
148
+ x = self.act(x)
149
+ path = x
150
+ for i in range(self.n_stages):
151
+ path = self.pool(path)
152
+ path = self.convs[i](path)
153
+ x = path + x
154
+ return x
155
+
156
+
157
+ class CondCRPBlock(nn.Module):
158
+ def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
159
+ super().__init__()
160
+ self.convs = nn.ModuleList()
161
+ self.norms = nn.ModuleList()
162
+ self.normalizer = normalizer
163
+ for i in range(n_stages):
164
+ self.norms.append(normalizer(features, num_classes, bias=True))
165
+ self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
166
+
167
+ self.n_stages = n_stages
168
+ self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
169
+ self.act = act
170
+
171
+ def forward(self, x, y):
172
+ x = self.act(x)
173
+ path = x
174
+ for i in range(self.n_stages):
175
+ path = self.norms[i](path, y)
176
+ path = self.pool(path)
177
+ path = self.convs[i](path)
178
+
179
+ x = path + x
180
+ return x
181
+
182
+
183
+ class RCUBlock(nn.Module):
184
+ def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
185
+ super().__init__()
186
+
187
+ for i in range(n_blocks):
188
+ for j in range(n_stages):
189
+ setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
190
+
191
+ self.stride = 1
192
+ self.n_blocks = n_blocks
193
+ self.n_stages = n_stages
194
+ self.act = act
195
+
196
+ def forward(self, x):
197
+ for i in range(self.n_blocks):
198
+ residual = x
199
+ for j in range(self.n_stages):
200
+ x = self.act(x)
201
+ x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
202
+
203
+ x += residual
204
+ return x
205
+
206
+
207
+ class CondRCUBlock(nn.Module):
208
+ def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
209
+ super().__init__()
210
+
211
+ for i in range(n_blocks):
212
+ for j in range(n_stages):
213
+ setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
214
+ setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
215
+
216
+ self.stride = 1
217
+ self.n_blocks = n_blocks
218
+ self.n_stages = n_stages
219
+ self.act = act
220
+ self.normalizer = normalizer
221
+
222
+ def forward(self, x, y):
223
+ for i in range(self.n_blocks):
224
+ residual = x
225
+ for j in range(self.n_stages):
226
+ x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
227
+ x = self.act(x)
228
+ x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
229
+
230
+ x += residual
231
+ return x
232
+
233
+
234
+ class MSFBlock(nn.Module):
235
+ def __init__(self, in_planes, features):
236
+ super().__init__()
237
+ assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
238
+ self.convs = nn.ModuleList()
239
+ self.features = features
240
+
241
+ for i in range(len(in_planes)):
242
+ self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
243
+
244
+ def forward(self, xs, shape):
245
+ sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
246
+ for i in range(len(self.convs)):
247
+ h = self.convs[i](xs[i])
248
+ h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
249
+ sums += h
250
+ return sums
251
+
252
+
253
+ class CondMSFBlock(nn.Module):
254
+ def __init__(self, in_planes, features, num_classes, normalizer):
255
+ super().__init__()
256
+ assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
257
+
258
+ self.convs = nn.ModuleList()
259
+ self.norms = nn.ModuleList()
260
+ self.features = features
261
+ self.normalizer = normalizer
262
+
263
+ for i in range(len(in_planes)):
264
+ self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
265
+ self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
266
+
267
+ def forward(self, xs, y, shape):
268
+ sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
269
+ for i in range(len(self.convs)):
270
+ h = self.norms[i](xs[i], y)
271
+ h = self.convs[i](h)
272
+ h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
273
+ sums += h
274
+ return sums
275
+
276
+
277
+ class RefineBlock(nn.Module):
278
+ def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
279
+ super().__init__()
280
+
281
+ assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
282
+ self.n_blocks = n_blocks = len(in_planes)
283
+
284
+ self.adapt_convs = nn.ModuleList()
285
+ for i in range(n_blocks):
286
+ self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
287
+
288
+ self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
289
+
290
+ if not start:
291
+ self.msf = MSFBlock(in_planes, features)
292
+
293
+ self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
294
+
295
+ def forward(self, xs, output_shape):
296
+ assert isinstance(xs, tuple) or isinstance(xs, list)
297
+ hs = []
298
+ for i in range(len(xs)):
299
+ h = self.adapt_convs[i](xs[i])
300
+ hs.append(h)
301
+
302
+ if self.n_blocks > 1:
303
+ h = self.msf(hs, output_shape)
304
+ else:
305
+ h = hs[0]
306
+
307
+ h = self.crp(h)
308
+ h = self.output_convs(h)
309
+
310
+ return h
311
+
312
+
313
+ class CondRefineBlock(nn.Module):
314
+ def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
315
+ super().__init__()
316
+
317
+ assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
318
+ self.n_blocks = n_blocks = len(in_planes)
319
+
320
+ self.adapt_convs = nn.ModuleList()
321
+ for i in range(n_blocks):
322
+ self.adapt_convs.append(
323
+ CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
324
+ )
325
+
326
+ self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)
327
+
328
+ if not start:
329
+ self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
330
+
331
+ self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
332
+
333
+ def forward(self, xs, y, output_shape):
334
+ assert isinstance(xs, tuple) or isinstance(xs, list)
335
+ hs = []
336
+ for i in range(len(xs)):
337
+ h = self.adapt_convs[i](xs[i], y)
338
+ hs.append(h)
339
+
340
+ if self.n_blocks > 1:
341
+ h = self.msf(hs, y, output_shape)
342
+ else:
343
+ h = hs[0]
344
+
345
+ h = self.crp(h, y)
346
+ h = self.output_convs(h, y)
347
+
348
+ return h
349
+
350
+
351
+ class ConvMeanPool(nn.Module):
352
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
353
+ super().__init__()
354
+ if not adjust_padding:
355
+ conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
356
+ self.conv = conv
357
+ else:
358
+ conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
359
+
360
+ self.conv = nn.Sequential(
361
+ nn.ZeroPad2d((1, 0, 1, 0)),
362
+ conv
363
+ )
364
+
365
+ def forward(self, inputs):
366
+ output = self.conv(inputs)
367
+ output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
368
+ output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
369
+ return output
370
+
371
+
372
+ class MeanPoolConv(nn.Module):
373
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
374
+ super().__init__()
375
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
376
+
377
+ def forward(self, inputs):
378
+ output = inputs
379
+ output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
380
+ output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
381
+ return self.conv(output)
382
+
383
+
384
+ class UpsampleConv(nn.Module):
385
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
386
+ super().__init__()
387
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
388
+ self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
389
+
390
+ def forward(self, inputs):
391
+ output = inputs
392
+ output = torch.cat([output, output, output, output], dim=1)
393
+ output = self.pixelshuffle(output)
394
+ return self.conv(output)
395
+
396
+
397
+ class ConditionalResidualBlock(nn.Module):
398
+ def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
399
+ normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
400
+ super().__init__()
401
+ self.non_linearity = act
402
+ self.input_dim = input_dim
403
+ self.output_dim = output_dim
404
+ self.resample = resample
405
+ self.normalization = normalization
406
+ if resample == 'down':
407
+ if dilation > 1:
408
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
409
+ self.normalize2 = normalization(input_dim, num_classes)
410
+ self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
411
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
412
+ else:
413
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim)
414
+ self.normalize2 = normalization(input_dim, num_classes)
415
+ self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
416
+ conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
417
+
418
+ elif resample is None:
419
+ if dilation > 1:
420
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
421
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
422
+ self.normalize2 = normalization(output_dim, num_classes)
423
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
424
+ else:
425
+ conv_shortcut = nn.Conv2d
426
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim)
427
+ self.normalize2 = normalization(output_dim, num_classes)
428
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim)
429
+ else:
430
+ raise Exception('invalid resample value')
431
+
432
+ if output_dim != input_dim or resample is not None:
433
+ self.shortcut = conv_shortcut(input_dim, output_dim)
434
+
435
+ self.normalize1 = normalization(input_dim, num_classes)
436
+
437
+ def forward(self, x, y):
438
+ output = self.normalize1(x, y)
439
+ output = self.non_linearity(output)
440
+ output = self.conv1(output)
441
+ output = self.normalize2(output, y)
442
+ output = self.non_linearity(output)
443
+ output = self.conv2(output)
444
+
445
+ if self.output_dim == self.input_dim and self.resample is None:
446
+ shortcut = x
447
+ else:
448
+ shortcut = self.shortcut(x)
449
+
450
+ return shortcut + output
451
+
452
+
453
+ class ResidualBlock(nn.Module):
454
+ def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
455
+ normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
456
+ super().__init__()
457
+ self.non_linearity = act
458
+ self.input_dim = input_dim
459
+ self.output_dim = output_dim
460
+ self.resample = resample
461
+ self.normalization = normalization
462
+ if resample == 'down':
463
+ if dilation > 1:
464
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
465
+ self.normalize2 = normalization(input_dim)
466
+ self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
467
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
468
+ else:
469
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim)
470
+ self.normalize2 = normalization(input_dim)
471
+ self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
472
+ conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
473
+
474
+ elif resample is None:
475
+ if dilation > 1:
476
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
477
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
478
+ self.normalize2 = normalization(output_dim)
479
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
480
+ else:
481
+ # conv_shortcut = nn.Conv2d ### Something wierd here.
482
+ conv_shortcut = partial(ncsn_conv1x1)
483
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim)
484
+ self.normalize2 = normalization(output_dim)
485
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim)
486
+ else:
487
+ raise Exception('invalid resample value')
488
+
489
+ if output_dim != input_dim or resample is not None:
490
+ self.shortcut = conv_shortcut(input_dim, output_dim)
491
+
492
+ self.normalize1 = normalization(input_dim)
493
+
494
+ def forward(self, x):
495
+ output = self.normalize1(x)
496
+ output = self.non_linearity(output)
497
+ output = self.conv1(output)
498
+ output = self.normalize2(output)
499
+ output = self.non_linearity(output)
500
+ output = self.conv2(output)
501
+
502
+ if self.output_dim == self.input_dim and self.resample is None:
503
+ shortcut = x
504
+ else:
505
+ shortcut = self.shortcut(x)
506
+
507
+ return shortcut + output
508
+
509
+
510
+ ###########################################################################
511
+ # Functions below are ported over from the DDPM codebase:
512
+ # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
513
+ ###########################################################################
514
+
515
+ def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
516
+ assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
517
+ half_dim = embedding_dim // 2
518
+ # magic number 10000 is from transformers
519
+ emb = math.log(max_positions) / (half_dim - 1)
520
+ # emb = math.log(2.) / (half_dim - 1)
521
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
522
+ # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
523
+ # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
524
+ emb = timesteps.float()[:, None] * emb[None, :]
525
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
526
+ if embedding_dim % 2 == 1: # zero pad
527
+ emb = F.pad(emb, (0, 1), mode='constant')
528
+ assert emb.shape == (timesteps.shape[0], embedding_dim)
529
+ return emb
530
+
531
+
532
+ def _einsum(a, b, c, x, y):
533
+ einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
534
+ return torch.einsum(einsum_str, x, y)
535
+
536
+
537
+ def contract_inner(x, y):
538
+ """tensordot(x, y, 1)."""
539
+ x_chars = list(string.ascii_lowercase[:len(x.shape)])
540
+ y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
541
+ y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
542
+ out_chars = x_chars[:-1] + y_chars[1:]
543
+ return _einsum(x_chars, y_chars, out_chars, x, y)
544
+
545
+
546
+ class NIN(nn.Module):
547
+ def __init__(self, in_dim, num_units, init_scale=0.1):
548
+ super().__init__()
549
+ self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
550
+ self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
551
+
552
+ def forward(self, x):
553
+ x = x.permute(0, 2, 3, 1)
554
+ y = contract_inner(x, self.W) + self.b
555
+ return y.permute(0, 3, 1, 2)
556
+
557
+
558
+ class AttnBlock(nn.Module):
559
+ """Channel-wise self-attention block."""
560
+ def __init__(self, channels):
561
+ super().__init__()
562
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
563
+ self.NIN_0 = NIN(channels, channels)
564
+ self.NIN_1 = NIN(channels, channels)
565
+ self.NIN_2 = NIN(channels, channels)
566
+ self.NIN_3 = NIN(channels, channels, init_scale=0.)
567
+
568
+ def forward(self, x):
569
+ B, C, H, W = x.shape
570
+ h = self.GroupNorm_0(x)
571
+ q = self.NIN_0(h)
572
+ k = self.NIN_1(h)
573
+ v = self.NIN_2(h)
574
+
575
+ w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
576
+ w = torch.reshape(w, (B, H, W, H * W))
577
+ w = F.softmax(w, dim=-1)
578
+ w = torch.reshape(w, (B, H, W, H, W))
579
+ h = torch.einsum('bhwij,bcij->bchw', w, v)
580
+ h = self.NIN_3(h)
581
+ return x + h
582
+
583
+
584
+ class Upsample(nn.Module):
585
+ def __init__(self, channels, with_conv=False):
586
+ super().__init__()
587
+ if with_conv:
588
+ self.Conv_0 = ddpm_conv3x3(channels, channels)
589
+ self.with_conv = with_conv
590
+
591
+ def forward(self, x):
592
+ B, C, H, W = x.shape
593
+ h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
594
+ if self.with_conv:
595
+ h = self.Conv_0(h)
596
+ return h
597
+
598
+
599
+ class Downsample(nn.Module):
600
+ def __init__(self, channels, with_conv=False):
601
+ super().__init__()
602
+ if with_conv:
603
+ self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
604
+ self.with_conv = with_conv
605
+
606
+ def forward(self, x):
607
+ B, C, H, W = x.shape
608
+ # Emulate 'SAME' padding
609
+ if self.with_conv:
610
+ x = F.pad(x, (0, 1, 0, 1))
611
+ x = self.Conv_0(x)
612
+ else:
613
+ x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
614
+
615
+ assert x.shape == (B, C, H // 2, W // 2)
616
+ return x
617
+
618
+
619
+ class ResnetBlockDDPM(nn.Module):
620
+ """The ResNet Blocks used in DDPM."""
621
+ def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
622
+ super().__init__()
623
+ if out_ch is None:
624
+ out_ch = in_ch
625
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
626
+ self.act = act
627
+ self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
628
+ if temb_dim is not None:
629
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
630
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
631
+ nn.init.zeros_(self.Dense_0.bias)
632
+
633
+ self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
634
+ self.Dropout_0 = nn.Dropout(dropout)
635
+ self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
636
+ if in_ch != out_ch:
637
+ if conv_shortcut:
638
+ self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
639
+ else:
640
+ self.NIN_0 = NIN(in_ch, out_ch)
641
+ self.out_ch = out_ch
642
+ self.in_ch = in_ch
643
+ self.conv_shortcut = conv_shortcut
644
+
645
+ def forward(self, x, temb=None):
646
+ B, C, H, W = x.shape
647
+ assert C == self.in_ch
648
+ out_ch = self.out_ch if self.out_ch else self.in_ch
649
+ h = self.act(self.GroupNorm_0(x))
650
+ h = self.Conv_0(h)
651
+ # Add bias to each feature map conditioned on the time embedding
652
+ if temb is not None:
653
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
654
+ h = self.act(self.GroupNorm_1(h))
655
+ h = self.Dropout_0(h)
656
+ h = self.Conv_1(h)
657
+ if C != out_ch:
658
+ if self.conv_shortcut:
659
+ x = self.Conv_2(x)
660
+ else:
661
+ x = self.NIN_0(x)
662
+ return x + h
sgmse/backbones/ncsnpp_utils/layerspp.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: skip-file
17
+ """Layers for defining NCSN++.
18
+ """
19
+ from . import layers
20
+ from . import up_or_down_sampling
21
+ import torch.nn as nn
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import numpy as np
25
+
26
+ conv1x1 = layers.ddpm_conv1x1
27
+ conv3x3 = layers.ddpm_conv3x3
28
+ NIN = layers.NIN
29
+ default_init = layers.default_init
30
+
31
+
32
+ class GaussianFourierProjection(nn.Module):
33
+ """Gaussian Fourier embeddings for noise levels."""
34
+
35
+ def __init__(self, embedding_size=256, scale=1.0):
36
+ super().__init__()
37
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
38
+
39
+ def forward(self, x):
40
+ x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
41
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
42
+
43
+
44
+ class Combine(nn.Module):
45
+ """Combine information from skip connections."""
46
+
47
+ def __init__(self, dim1, dim2, method='cat'):
48
+ super().__init__()
49
+ self.Conv_0 = conv1x1(dim1, dim2)
50
+ self.method = method
51
+
52
+ def forward(self, x, y):
53
+ h = self.Conv_0(x)
54
+ if self.method == 'cat':
55
+ return torch.cat([h, y], dim=1)
56
+ elif self.method == 'sum':
57
+ return h + y
58
+ else:
59
+ raise ValueError(f'Method {self.method} not recognized.')
60
+
61
+
62
+ class AttnBlockpp(nn.Module):
63
+ """Channel-wise self-attention block. Modified from DDPM."""
64
+
65
+ def __init__(self, channels, skip_rescale=False, init_scale=0.):
66
+ super().__init__()
67
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
68
+ eps=1e-6)
69
+ self.NIN_0 = NIN(channels, channels)
70
+ self.NIN_1 = NIN(channels, channels)
71
+ self.NIN_2 = NIN(channels, channels)
72
+ self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
73
+ self.skip_rescale = skip_rescale
74
+
75
+ def forward(self, x):
76
+ B, C, H, W = x.shape
77
+ h = self.GroupNorm_0(x)
78
+ q = self.NIN_0(h)
79
+ k = self.NIN_1(h)
80
+ v = self.NIN_2(h)
81
+
82
+ w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
83
+ w = torch.reshape(w, (B, H, W, H * W))
84
+ w = F.softmax(w, dim=-1)
85
+ w = torch.reshape(w, (B, H, W, H, W))
86
+ h = torch.einsum('bhwij,bcij->bchw', w, v)
87
+ h = self.NIN_3(h)
88
+ if not self.skip_rescale:
89
+ return x + h
90
+ else:
91
+ return (x + h) / np.sqrt(2.)
92
+
93
+
94
+ class Upsample(nn.Module):
95
+ def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
96
+ fir_kernel=(1, 3, 3, 1)):
97
+ super().__init__()
98
+ out_ch = out_ch if out_ch else in_ch
99
+ if not fir:
100
+ if with_conv:
101
+ self.Conv_0 = conv3x3(in_ch, out_ch)
102
+ else:
103
+ if with_conv:
104
+ self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
105
+ kernel=3, up=True,
106
+ resample_kernel=fir_kernel,
107
+ use_bias=True,
108
+ kernel_init=default_init())
109
+ self.fir = fir
110
+ self.with_conv = with_conv
111
+ self.fir_kernel = fir_kernel
112
+ self.out_ch = out_ch
113
+
114
+ def forward(self, x):
115
+ B, C, H, W = x.shape
116
+ if not self.fir:
117
+ h = F.interpolate(x, (H * 2, W * 2), 'nearest')
118
+ if self.with_conv:
119
+ h = self.Conv_0(h)
120
+ else:
121
+ if not self.with_conv:
122
+ h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
123
+ else:
124
+ h = self.Conv2d_0(x)
125
+
126
+ return h
127
+
128
+
129
+ class Downsample(nn.Module):
130
+ def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
131
+ fir_kernel=(1, 3, 3, 1)):
132
+ super().__init__()
133
+ out_ch = out_ch if out_ch else in_ch
134
+ if not fir:
135
+ if with_conv:
136
+ self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
137
+ else:
138
+ if with_conv:
139
+ self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
140
+ kernel=3, down=True,
141
+ resample_kernel=fir_kernel,
142
+ use_bias=True,
143
+ kernel_init=default_init())
144
+ self.fir = fir
145
+ self.fir_kernel = fir_kernel
146
+ self.with_conv = with_conv
147
+ self.out_ch = out_ch
148
+
149
+ def forward(self, x):
150
+ B, C, H, W = x.shape
151
+ if not self.fir:
152
+ if self.with_conv:
153
+ x = F.pad(x, (0, 1, 0, 1))
154
+ x = self.Conv_0(x)
155
+ else:
156
+ x = F.avg_pool2d(x, 2, stride=2)
157
+ else:
158
+ if not self.with_conv:
159
+ x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
160
+ else:
161
+ x = self.Conv2d_0(x)
162
+
163
+ return x
164
+
165
+
166
+ class ResnetBlockDDPMpp(nn.Module):
167
+ """ResBlock adapted from DDPM."""
168
+
169
+ def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
170
+ dropout=0.1, skip_rescale=False, init_scale=0.):
171
+ super().__init__()
172
+ out_ch = out_ch if out_ch else in_ch
173
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
174
+ self.Conv_0 = conv3x3(in_ch, out_ch)
175
+ if temb_dim is not None:
176
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
177
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
178
+ nn.init.zeros_(self.Dense_0.bias)
179
+ self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
180
+ self.Dropout_0 = nn.Dropout(dropout)
181
+ self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
182
+ if in_ch != out_ch:
183
+ if conv_shortcut:
184
+ self.Conv_2 = conv3x3(in_ch, out_ch)
185
+ else:
186
+ self.NIN_0 = NIN(in_ch, out_ch)
187
+
188
+ self.skip_rescale = skip_rescale
189
+ self.act = act
190
+ self.out_ch = out_ch
191
+ self.conv_shortcut = conv_shortcut
192
+
193
+ def forward(self, x, temb=None):
194
+ h = self.act(self.GroupNorm_0(x))
195
+ h = self.Conv_0(h)
196
+ if temb is not None:
197
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
198
+ h = self.act(self.GroupNorm_1(h))
199
+ h = self.Dropout_0(h)
200
+ h = self.Conv_1(h)
201
+ if x.shape[1] != self.out_ch:
202
+ if self.conv_shortcut:
203
+ x = self.Conv_2(x)
204
+ else:
205
+ x = self.NIN_0(x)
206
+ if not self.skip_rescale:
207
+ return x + h
208
+ else:
209
+ return (x + h) / np.sqrt(2.)
210
+
211
+
212
+ class ResnetBlockBigGANpp(nn.Module):
213
+ def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
214
+ dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
215
+ skip_rescale=True, init_scale=0.):
216
+ super().__init__()
217
+
218
+ out_ch = out_ch if out_ch else in_ch
219
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
220
+ self.up = up
221
+ self.down = down
222
+ self.fir = fir
223
+ self.fir_kernel = fir_kernel
224
+
225
+ self.Conv_0 = conv3x3(in_ch, out_ch)
226
+ if temb_dim is not None:
227
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
228
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
229
+ nn.init.zeros_(self.Dense_0.bias)
230
+
231
+ self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
232
+ self.Dropout_0 = nn.Dropout(dropout)
233
+ self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
234
+ if in_ch != out_ch or up or down:
235
+ self.Conv_2 = conv1x1(in_ch, out_ch)
236
+
237
+ self.skip_rescale = skip_rescale
238
+ self.act = act
239
+ self.in_ch = in_ch
240
+ self.out_ch = out_ch
241
+
242
+ def forward(self, x, temb=None):
243
+ h = self.act(self.GroupNorm_0(x))
244
+
245
+ if self.up:
246
+ if self.fir:
247
+ h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
248
+ x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
249
+ else:
250
+ h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
251
+ x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
252
+ elif self.down:
253
+ if self.fir:
254
+ h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
255
+ x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
256
+ else:
257
+ h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
258
+ x = up_or_down_sampling.naive_downsample_2d(x, factor=2)
259
+
260
+ h = self.Conv_0(h)
261
+ # Add bias to each feature map conditioned on the time embedding
262
+ if temb is not None:
263
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
264
+ h = self.act(self.GroupNorm_1(h))
265
+ h = self.Dropout_0(h)
266
+ h = self.Conv_1(h)
267
+
268
+ if self.in_ch != self.out_ch or self.up or self.down:
269
+ x = self.Conv_2(x)
270
+
271
+ if not self.skip_rescale:
272
+ return x + h
273
+ else:
274
+ return (x + h) / np.sqrt(2.)
sgmse/backbones/ncsnpp_utils/normalization.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Normalization layers."""
17
+ import torch.nn as nn
18
+ import torch
19
+ import functools
20
+
21
+
22
+ def get_normalization(config, conditional=False):
23
+ """Obtain normalization modules from the config file."""
24
+ norm = config.model.normalization
25
+ if conditional:
26
+ if norm == 'InstanceNorm++':
27
+ return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
28
+ else:
29
+ raise NotImplementedError(f'{norm} not implemented yet.')
30
+ else:
31
+ if norm == 'InstanceNorm':
32
+ return nn.InstanceNorm2d
33
+ elif norm == 'InstanceNorm++':
34
+ return InstanceNorm2dPlus
35
+ elif norm == 'VarianceNorm':
36
+ return VarianceNorm2d
37
+ elif norm == 'GroupNorm':
38
+ return nn.GroupNorm
39
+ else:
40
+ raise ValueError('Unknown normalization: %s' % norm)
41
+
42
+
43
+ class ConditionalBatchNorm2d(nn.Module):
44
+ def __init__(self, num_features, num_classes, bias=True):
45
+ super().__init__()
46
+ self.num_features = num_features
47
+ self.bias = bias
48
+ self.bn = nn.BatchNorm2d(num_features, affine=False)
49
+ if self.bias:
50
+ self.embed = nn.Embedding(num_classes, num_features * 2)
51
+ self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
52
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
53
+ else:
54
+ self.embed = nn.Embedding(num_classes, num_features)
55
+ self.embed.weight.data.uniform_()
56
+
57
+ def forward(self, x, y):
58
+ out = self.bn(x)
59
+ if self.bias:
60
+ gamma, beta = self.embed(y).chunk(2, dim=1)
61
+ out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
62
+ else:
63
+ gamma = self.embed(y)
64
+ out = gamma.view(-1, self.num_features, 1, 1) * out
65
+ return out
66
+
67
+
68
+ class ConditionalInstanceNorm2d(nn.Module):
69
+ def __init__(self, num_features, num_classes, bias=True):
70
+ super().__init__()
71
+ self.num_features = num_features
72
+ self.bias = bias
73
+ self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
74
+ if bias:
75
+ self.embed = nn.Embedding(num_classes, num_features * 2)
76
+ self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
77
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
78
+ else:
79
+ self.embed = nn.Embedding(num_classes, num_features)
80
+ self.embed.weight.data.uniform_()
81
+
82
+ def forward(self, x, y):
83
+ h = self.instance_norm(x)
84
+ if self.bias:
85
+ gamma, beta = self.embed(y).chunk(2, dim=-1)
86
+ out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
87
+ else:
88
+ gamma = self.embed(y)
89
+ out = gamma.view(-1, self.num_features, 1, 1) * h
90
+ return out
91
+
92
+
93
+ class ConditionalVarianceNorm2d(nn.Module):
94
+ def __init__(self, num_features, num_classes, bias=False):
95
+ super().__init__()
96
+ self.num_features = num_features
97
+ self.bias = bias
98
+ self.embed = nn.Embedding(num_classes, num_features)
99
+ self.embed.weight.data.normal_(1, 0.02)
100
+
101
+ def forward(self, x, y):
102
+ vars = torch.var(x, dim=(2, 3), keepdim=True)
103
+ h = x / torch.sqrt(vars + 1e-5)
104
+
105
+ gamma = self.embed(y)
106
+ out = gamma.view(-1, self.num_features, 1, 1) * h
107
+ return out
108
+
109
+
110
+ class VarianceNorm2d(nn.Module):
111
+ def __init__(self, num_features, bias=False):
112
+ super().__init__()
113
+ self.num_features = num_features
114
+ self.bias = bias
115
+ self.alpha = nn.Parameter(torch.zeros(num_features))
116
+ self.alpha.data.normal_(1, 0.02)
117
+
118
+ def forward(self, x):
119
+ vars = torch.var(x, dim=(2, 3), keepdim=True)
120
+ h = x / torch.sqrt(vars + 1e-5)
121
+
122
+ out = self.alpha.view(-1, self.num_features, 1, 1) * h
123
+ return out
124
+
125
+
126
+ class ConditionalNoneNorm2d(nn.Module):
127
+ def __init__(self, num_features, num_classes, bias=True):
128
+ super().__init__()
129
+ self.num_features = num_features
130
+ self.bias = bias
131
+ if bias:
132
+ self.embed = nn.Embedding(num_classes, num_features * 2)
133
+ self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
134
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
135
+ else:
136
+ self.embed = nn.Embedding(num_classes, num_features)
137
+ self.embed.weight.data.uniform_()
138
+
139
+ def forward(self, x, y):
140
+ if self.bias:
141
+ gamma, beta = self.embed(y).chunk(2, dim=-1)
142
+ out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
143
+ else:
144
+ gamma = self.embed(y)
145
+ out = gamma.view(-1, self.num_features, 1, 1) * x
146
+ return out
147
+
148
+
149
+ class NoneNorm2d(nn.Module):
150
+ def __init__(self, num_features, bias=True):
151
+ super().__init__()
152
+
153
+ def forward(self, x):
154
+ return x
155
+
156
+
157
+ class InstanceNorm2dPlus(nn.Module):
158
+ def __init__(self, num_features, bias=True):
159
+ super().__init__()
160
+ self.num_features = num_features
161
+ self.bias = bias
162
+ self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
163
+ self.alpha = nn.Parameter(torch.zeros(num_features))
164
+ self.gamma = nn.Parameter(torch.zeros(num_features))
165
+ self.alpha.data.normal_(1, 0.02)
166
+ self.gamma.data.normal_(1, 0.02)
167
+ if bias:
168
+ self.beta = nn.Parameter(torch.zeros(num_features))
169
+
170
+ def forward(self, x):
171
+ means = torch.mean(x, dim=(2, 3))
172
+ m = torch.mean(means, dim=-1, keepdim=True)
173
+ v = torch.var(means, dim=-1, keepdim=True)
174
+ means = (means - m) / (torch.sqrt(v + 1e-5))
175
+ h = self.instance_norm(x)
176
+
177
+ if self.bias:
178
+ h = h + means[..., None, None] * self.alpha[..., None, None]
179
+ out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
180
+ else:
181
+ h = h + means[..., None, None] * self.alpha[..., None, None]
182
+ out = self.gamma.view(-1, self.num_features, 1, 1) * h
183
+ return out
184
+
185
+
186
+ class ConditionalInstanceNorm2dPlus(nn.Module):
187
+ def __init__(self, num_features, num_classes, bias=True):
188
+ super().__init__()
189
+ self.num_features = num_features
190
+ self.bias = bias
191
+ self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
192
+ if bias:
193
+ self.embed = nn.Embedding(num_classes, num_features * 3)
194
+ self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
195
+ self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
196
+ else:
197
+ self.embed = nn.Embedding(num_classes, 2 * num_features)
198
+ self.embed.weight.data.normal_(1, 0.02)
199
+
200
+ def forward(self, x, y):
201
+ means = torch.mean(x, dim=(2, 3))
202
+ m = torch.mean(means, dim=-1, keepdim=True)
203
+ v = torch.var(means, dim=-1, keepdim=True)
204
+ means = (means - m) / (torch.sqrt(v + 1e-5))
205
+ h = self.instance_norm(x)
206
+
207
+ if self.bias:
208
+ gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
209
+ h = h + means[..., None, None] * alpha[..., None, None]
210
+ out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
211
+ else:
212
+ gamma, alpha = self.embed(y).chunk(2, dim=-1)
213
+ h = h + means[..., None, None] * alpha[..., None, None]
214
+ out = gamma.view(-1, self.num_features, 1, 1) * h
215
+ return out
sgmse/backbones/ncsnpp_utils/op/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .upfirdn2d import upfirdn2d
sgmse/backbones/ncsnpp_utils/op/fused_act.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ fused = load(
12
+ "fused",
13
+ sources=[
14
+ os.path.join(module_path, "fused_bias_act.cpp"),
15
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output, empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ grad_bias = grad_input.sum(dim).detach()
39
+
40
+ return grad_input, grad_bias
41
+
42
+ @staticmethod
43
+ def backward(ctx, gradgrad_input, gradgrad_bias):
44
+ out, = ctx.saved_tensors
45
+ gradgrad_out = fused.fused_bias_act(
46
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
47
+ )
48
+
49
+ return gradgrad_out, None, None, None
50
+
51
+
52
+ class FusedLeakyReLUFunction(Function):
53
+ @staticmethod
54
+ def forward(ctx, input, bias, negative_slope, scale):
55
+ empty = input.new_empty(0)
56
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
57
+ ctx.save_for_backward(out)
58
+ ctx.negative_slope = negative_slope
59
+ ctx.scale = scale
60
+
61
+ return out
62
+
63
+ @staticmethod
64
+ def backward(ctx, grad_output):
65
+ out, = ctx.saved_tensors
66
+
67
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
68
+ grad_output, out, ctx.negative_slope, ctx.scale
69
+ )
70
+
71
+ return grad_input, grad_bias, None, None
72
+
73
+
74
+ class FusedLeakyReLU(nn.Module):
75
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
76
+ super().__init__()
77
+
78
+ self.bias = nn.Parameter(torch.zeros(channel))
79
+ self.negative_slope = negative_slope
80
+ self.scale = scale
81
+
82
+ def forward(self, input):
83
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
84
+
85
+
86
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
87
+ if input.device.type == "cpu":
88
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
89
+ return (
90
+ F.leaky_relu(
91
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
92
+ )
93
+ * scale
94
+ )
95
+
96
+ else:
97
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21
+ }
sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
99
+ }
sgmse/backbones/ncsnpp_utils/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5
+ int up_x, int up_y, int down_x, int down_y,
6
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11
+
12
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13
+ int up_x, int up_y, int down_x, int down_y,
14
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(kernel);
17
+
18
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23
+ }
sgmse/backbones/ncsnpp_utils/op/upfirdn2d.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+
8
+
9
+ module_path = os.path.dirname(__file__)
10
+
11
+ if torch.cuda.is_available():
12
+ upfirdn2d_op = load(
13
+ "upfirdn2d",
14
+ sources=[
15
+ os.path.join(module_path, "upfirdn2d.cpp"),
16
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
17
+ ],
18
+ )
19
+ else:
20
+ upfirdn2d_op = None
21
+
22
+ class UpFirDn2dBackward(Function):
23
+ @staticmethod
24
+ def forward(
25
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
26
+ ):
27
+
28
+ up_x, up_y = up
29
+ down_x, down_y = down
30
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
31
+
32
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
33
+
34
+ grad_input = upfirdn2d_op.upfirdn2d(
35
+ grad_output,
36
+ grad_kernel,
37
+ down_x,
38
+ down_y,
39
+ up_x,
40
+ up_y,
41
+ g_pad_x0,
42
+ g_pad_x1,
43
+ g_pad_y0,
44
+ g_pad_y1,
45
+ )
46
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
47
+
48
+ ctx.save_for_backward(kernel)
49
+
50
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
51
+
52
+ ctx.up_x = up_x
53
+ ctx.up_y = up_y
54
+ ctx.down_x = down_x
55
+ ctx.down_y = down_y
56
+ ctx.pad_x0 = pad_x0
57
+ ctx.pad_x1 = pad_x1
58
+ ctx.pad_y0 = pad_y0
59
+ ctx.pad_y1 = pad_y1
60
+ ctx.in_size = in_size
61
+ ctx.out_size = out_size
62
+
63
+ return grad_input
64
+
65
+ @staticmethod
66
+ def backward(ctx, gradgrad_input):
67
+ kernel, = ctx.saved_tensors
68
+
69
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
70
+
71
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
72
+ gradgrad_input,
73
+ kernel,
74
+ ctx.up_x,
75
+ ctx.up_y,
76
+ ctx.down_x,
77
+ ctx.down_y,
78
+ ctx.pad_x0,
79
+ ctx.pad_x1,
80
+ ctx.pad_y0,
81
+ ctx.pad_y1,
82
+ )
83
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
84
+ gradgrad_out = gradgrad_out.view(
85
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
86
+ )
87
+
88
+ return gradgrad_out, None, None, None, None, None, None, None, None
89
+
90
+
91
+ class UpFirDn2d(Function):
92
+ @staticmethod
93
+ def forward(ctx, input, kernel, up, down, pad):
94
+ up_x, up_y = up
95
+ down_x, down_y = down
96
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
97
+
98
+ kernel_h, kernel_w = kernel.shape
99
+ batch, channel, in_h, in_w = input.shape
100
+ ctx.in_size = input.shape
101
+
102
+ input = input.reshape(-1, in_h, in_w, 1)
103
+
104
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
105
+
106
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
107
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
108
+ ctx.out_size = (out_h, out_w)
109
+
110
+ ctx.up = (up_x, up_y)
111
+ ctx.down = (down_x, down_y)
112
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
113
+
114
+ g_pad_x0 = kernel_w - pad_x0 - 1
115
+ g_pad_y0 = kernel_h - pad_y0 - 1
116
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
117
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
118
+
119
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
120
+
121
+ out = upfirdn2d_op.upfirdn2d(
122
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
123
+ )
124
+ # out = out.view(major, out_h, out_w, minor)
125
+ out = out.view(-1, channel, out_h, out_w)
126
+
127
+ return out
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ kernel, grad_kernel = ctx.saved_tensors
132
+
133
+ grad_input = UpFirDn2dBackward.apply(
134
+ grad_output,
135
+ kernel,
136
+ grad_kernel,
137
+ ctx.up,
138
+ ctx.down,
139
+ ctx.pad,
140
+ ctx.g_pad,
141
+ ctx.in_size,
142
+ ctx.out_size,
143
+ )
144
+
145
+ return grad_input, None, None, None, None
146
+
147
+
148
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
149
+ if input.device.type == "cpu":
150
+ out = upfirdn2d_native(
151
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
152
+ )
153
+
154
+ else:
155
+ out = UpFirDn2d.apply(
156
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
157
+ )
158
+
159
+ return out
160
+
161
+
162
+ def upfirdn2d_native(
163
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
164
+ ):
165
+ _, channel, in_h, in_w = input.shape
166
+ input = input.reshape(-1, in_h, in_w, 1)
167
+
168
+ _, in_h, in_w, minor = input.shape
169
+ kernel_h, kernel_w = kernel.shape
170
+
171
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
172
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
173
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
174
+
175
+ out = F.pad(
176
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
177
+ )
178
+ out = out[
179
+ :,
180
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
181
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
182
+ :,
183
+ ]
184
+
185
+ out = out.permute(0, 3, 1, 2)
186
+ out = out.reshape(
187
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
188
+ )
189
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
190
+ out = F.conv2d(out, w)
191
+ out = out.reshape(
192
+ -1,
193
+ minor,
194
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
195
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
196
+ )
197
+ out = out.permute(0, 2, 3, 1)
198
+ out = out[:, ::down_y, ::down_x, :]
199
+
200
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
201
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
202
+
203
+ return out.view(-1, channel, out_h, out_w)
sgmse/backbones/ncsnpp_utils/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }
sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Layers used for up-sampling or down-sampling images.
2
+
3
+ Many functions are ported from https://github.com/NVlabs/stylegan2.
4
+ """
5
+
6
+ import torch.nn as nn
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from .op import upfirdn2d
11
+
12
+
13
+ # Function ported from StyleGAN2
14
+ def get_weight(module,
15
+ shape,
16
+ weight_var='weight',
17
+ kernel_init=None):
18
+ """Get/create weight tensor for a convolution or fully-connected layer."""
19
+
20
+ return module.param(weight_var, kernel_init, shape)
21
+
22
+
23
+ class Conv2d(nn.Module):
24
+ """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
25
+
26
+ def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
27
+ resample_kernel=(1, 3, 3, 1),
28
+ use_bias=True,
29
+ kernel_init=None):
30
+ super().__init__()
31
+ assert not (up and down)
32
+ assert kernel >= 1 and kernel % 2 == 1
33
+ self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
34
+ if kernel_init is not None:
35
+ self.weight.data = kernel_init(self.weight.data.shape)
36
+ if use_bias:
37
+ self.bias = nn.Parameter(torch.zeros(out_ch))
38
+
39
+ self.up = up
40
+ self.down = down
41
+ self.resample_kernel = resample_kernel
42
+ self.kernel = kernel
43
+ self.use_bias = use_bias
44
+
45
+ def forward(self, x):
46
+ if self.up:
47
+ x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
48
+ elif self.down:
49
+ x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
50
+ else:
51
+ x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
52
+
53
+ if self.use_bias:
54
+ x = x + self.bias.reshape(1, -1, 1, 1)
55
+
56
+ return x
57
+
58
+
59
+ def naive_upsample_2d(x, factor=2):
60
+ _N, C, H, W = x.shape
61
+ x = torch.reshape(x, (-1, C, H, 1, W, 1))
62
+ x = x.repeat(1, 1, 1, factor, 1, factor)
63
+ return torch.reshape(x, (-1, C, H * factor, W * factor))
64
+
65
+
66
+ def naive_downsample_2d(x, factor=2):
67
+ _N, C, H, W = x.shape
68
+ x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
69
+ return torch.mean(x, dim=(3, 5))
70
+
71
+
72
+ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
73
+ """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
74
+
75
+ Padding is performed only once at the beginning, not between the
76
+ operations.
77
+ The fused op is considerably more efficient than performing the same
78
+ calculation
79
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
80
+ Args:
81
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
82
+ C]`.
83
+ w: Weight tensor of the shape `[filterH, filterW, inChannels,
84
+ outChannels]`. Grouped convolution can be performed by `inChannels =
85
+ x.shape[0] // numGroups`.
86
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
87
+ (separable). The default is `[1] * factor`, which corresponds to
88
+ nearest-neighbor upsampling.
89
+ factor: Integer upsampling factor (default: 2).
90
+ gain: Scaling factor for signal magnitude (default: 1.0).
91
+
92
+ Returns:
93
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
94
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
95
+ """
96
+
97
+ assert isinstance(factor, int) and factor >= 1
98
+
99
+ # Check weight shape.
100
+ assert len(w.shape) == 4
101
+ convH = w.shape[2]
102
+ convW = w.shape[3]
103
+ inC = w.shape[1]
104
+ outC = w.shape[0]
105
+
106
+ assert convW == convH
107
+
108
+ # Setup filter kernel.
109
+ if k is None:
110
+ k = [1] * factor
111
+ k = _setup_kernel(k) * (gain * (factor ** 2))
112
+ p = (k.shape[0] - factor) - (convW - 1)
113
+
114
+ stride = (factor, factor)
115
+
116
+ # Determine data dimensions.
117
+ stride = [1, 1, factor, factor]
118
+ output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
119
+ output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
120
+ output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
121
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
122
+ num_groups = _shape(x, 1) // inC
123
+
124
+ # Transpose weights.
125
+ w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
126
+ w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
127
+ w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
128
+
129
+ x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
130
+ ## Original TF code.
131
+ # x = tf.nn.conv2d_transpose(
132
+ # x,
133
+ # w,
134
+ # output_shape=output_shape,
135
+ # strides=stride,
136
+ # padding='VALID',
137
+ # data_format=data_format)
138
+ ## JAX equivalent
139
+
140
+ return upfirdn2d(x, torch.tensor(k, device=x.device),
141
+ pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
142
+
143
+
144
+ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
145
+ """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
146
+
147
+ Padding is performed only once at the beginning, not between the operations.
148
+ The fused op is considerably more efficient than performing the same
149
+ calculation
150
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
151
+ Args:
152
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
153
+ C]`.
154
+ w: Weight tensor of the shape `[filterH, filterW, inChannels,
155
+ outChannels]`. Grouped convolution can be performed by `inChannels =
156
+ x.shape[0] // numGroups`.
157
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
158
+ (separable). The default is `[1] * factor`, which corresponds to
159
+ average pooling.
160
+ factor: Integer downsampling factor (default: 2).
161
+ gain: Scaling factor for signal magnitude (default: 1.0).
162
+
163
+ Returns:
164
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
165
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
166
+ """
167
+
168
+ assert isinstance(factor, int) and factor >= 1
169
+ _outC, _inC, convH, convW = w.shape
170
+ assert convW == convH
171
+ if k is None:
172
+ k = [1] * factor
173
+ k = _setup_kernel(k) * gain
174
+ p = (k.shape[0] - factor) + (convW - 1)
175
+ s = [factor, factor]
176
+ x = upfirdn2d(x, torch.tensor(k, device=x.device),
177
+ pad=((p + 1) // 2, p // 2))
178
+ return F.conv2d(x, w, stride=s, padding=0)
179
+
180
+
181
+ def _setup_kernel(k):
182
+ k = np.asarray(k, dtype=np.float32)
183
+ if k.ndim == 1:
184
+ k = np.outer(k, k)
185
+ k /= np.sum(k)
186
+ assert k.ndim == 2
187
+ assert k.shape[0] == k.shape[1]
188
+ return k
189
+
190
+
191
+ def _shape(x, dim):
192
+ return x.shape[dim]
193
+
194
+
195
+ def upsample_2d(x, k=None, factor=2, gain=1):
196
+ r"""Upsample a batch of 2D images with the given filter.
197
+
198
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
199
+ and upsamples each image with the given filter. The filter is normalized so
200
+ that
201
+ if the input pixels are constant, they will be scaled by the specified
202
+ `gain`.
203
+ Pixels outside the image are assumed to be zero, and the filter is padded
204
+ with
205
+ zeros so that its shape is a multiple of the upsampling factor.
206
+ Args:
207
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
208
+ C]`.
209
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
210
+ (separable). The default is `[1] * factor`, which corresponds to
211
+ nearest-neighbor upsampling.
212
+ factor: Integer upsampling factor (default: 2).
213
+ gain: Scaling factor for signal magnitude (default: 1.0).
214
+
215
+ Returns:
216
+ Tensor of the shape `[N, C, H * factor, W * factor]`
217
+ """
218
+ assert isinstance(factor, int) and factor >= 1
219
+ if k is None:
220
+ k = [1] * factor
221
+ k = _setup_kernel(k) * (gain * (factor ** 2))
222
+ p = k.shape[0] - factor
223
+ return upfirdn2d(x, torch.tensor(k, device=x.device),
224
+ up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
225
+
226
+
227
+ def downsample_2d(x, k=None, factor=2, gain=1):
228
+ r"""Downsample a batch of 2D images with the given filter.
229
+
230
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
231
+ and downsamples each image with the given filter. The filter is normalized
232
+ so that
233
+ if the input pixels are constant, they will be scaled by the specified
234
+ `gain`.
235
+ Pixels outside the image are assumed to be zero, and the filter is padded
236
+ with
237
+ zeros so that its shape is a multiple of the downsampling factor.
238
+ Args:
239
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
240
+ C]`.
241
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
242
+ (separable). The default is `[1] * factor`, which corresponds to
243
+ average pooling.
244
+ factor: Integer downsampling factor (default: 2).
245
+ gain: Scaling factor for signal magnitude (default: 1.0).
246
+
247
+ Returns:
248
+ Tensor of the shape `[N, C, H // factor, W // factor]`
249
+ """
250
+
251
+ assert isinstance(factor, int) and factor >= 1
252
+ if k is None:
253
+ k = [1] * factor
254
+ k = _setup_kernel(k) * gain
255
+ p = k.shape[0] - factor
256
+ return upfirdn2d(x, torch.tensor(k, device=x.device),
257
+ down=factor, pad=((p + 1) // 2, p // 2))
sgmse/backbones/ncsnpp_utils/utils.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """All functions and modules related to model definition.
17
+ """
18
+
19
+ import torch
20
+
21
+ import numpy as np
22
+ from ...sdes import OUVESDE, OUVPSDE
23
+
24
+
25
+ _MODELS = {}
26
+
27
+
28
+ def register_model(cls=None, *, name=None):
29
+ """A decorator for registering model classes."""
30
+
31
+ def _register(cls):
32
+ if name is None:
33
+ local_name = cls.__name__
34
+ else:
35
+ local_name = name
36
+ if local_name in _MODELS:
37
+ raise ValueError(f'Already registered model with name: {local_name}')
38
+ _MODELS[local_name] = cls
39
+ return cls
40
+
41
+ if cls is None:
42
+ return _register
43
+ else:
44
+ return _register(cls)
45
+
46
+
47
+ def get_model(name):
48
+ return _MODELS[name]
49
+
50
+
51
+ def get_sigmas(sigma_min, sigma_max, num_scales):
52
+ """Get sigmas --- the set of noise levels for SMLD from config files.
53
+ Args:
54
+ config: A ConfigDict object parsed from the config file
55
+ Returns:
56
+ sigmas: a jax numpy arrary of noise levels
57
+ """
58
+ sigmas = np.exp(
59
+ np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
60
+
61
+ return sigmas
62
+
63
+
64
+ def get_ddpm_params(config):
65
+ """Get betas and alphas --- parameters used in the original DDPM paper."""
66
+ num_diffusion_timesteps = 1000
67
+ # parameters need to be adapted if number of time steps differs from 1000
68
+ beta_start = config.model.beta_min / config.model.num_scales
69
+ beta_end = config.model.beta_max / config.model.num_scales
70
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
71
+
72
+ alphas = 1. - betas
73
+ alphas_cumprod = np.cumprod(alphas, axis=0)
74
+ sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
75
+ sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
76
+
77
+ return {
78
+ 'betas': betas,
79
+ 'alphas': alphas,
80
+ 'alphas_cumprod': alphas_cumprod,
81
+ 'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
82
+ 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
83
+ 'beta_min': beta_start * (num_diffusion_timesteps - 1),
84
+ 'beta_max': beta_end * (num_diffusion_timesteps - 1),
85
+ 'num_diffusion_timesteps': num_diffusion_timesteps
86
+ }
87
+
88
+
89
+ def create_model(config):
90
+ """Create the score model."""
91
+ model_name = config.model.name
92
+ score_model = get_model(model_name)(config)
93
+ score_model = score_model.to(config.device)
94
+ score_model = torch.nn.DataParallel(score_model)
95
+ return score_model
96
+
97
+
98
+ def get_model_fn(model, train=False):
99
+ """Create a function to give the output of the score-based model.
100
+
101
+ Args:
102
+ model: The score model.
103
+ train: `True` for training and `False` for evaluation.
104
+
105
+ Returns:
106
+ A model function.
107
+ """
108
+
109
+ def model_fn(x, labels):
110
+ """Compute the output of the score-based model.
111
+
112
+ Args:
113
+ x: A mini-batch of input data.
114
+ labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
115
+ for different models.
116
+
117
+ Returns:
118
+ A tuple of (model output, new mutable states)
119
+ """
120
+ if not train:
121
+ model.eval()
122
+ return model(x, labels)
123
+ else:
124
+ model.train()
125
+ return model(x, labels)
126
+
127
+ return model_fn
128
+
129
+
130
+ def get_score_fn(sde, model, train=False, continuous=False):
131
+ """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
132
+
133
+ Args:
134
+ sde: An `sde_lib.SDE` object that represents the forward SDE.
135
+ model: A score model.
136
+ train: `True` for training and `False` for evaluation.
137
+ continuous: If `True`, the score-based model is expected to directly take continuous time steps.
138
+
139
+ Returns:
140
+ A score function.
141
+ """
142
+ model_fn = get_model_fn(model, train=train)
143
+
144
+ if isinstance(sde, OUVPSDE):
145
+ def score_fn(x, t):
146
+ # Scale neural network output by standard deviation and flip sign
147
+ if continuous:
148
+ # For VP-trained models, t=0 corresponds to the lowest noise level
149
+ # The maximum value of time embedding is assumed to 999 for
150
+ # continuously-trained models.
151
+ labels = t * 999
152
+ score = model_fn(x, labels)
153
+ std = sde.marginal_prob(torch.zeros_like(x), t)[1]
154
+ else:
155
+ # For VP-trained models, t=0 corresponds to the lowest noise level
156
+ labels = t * (sde.N - 1)
157
+ score = model_fn(x, labels)
158
+ std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
159
+
160
+ score = -score / std[:, None, None, None]
161
+ return score
162
+
163
+ elif isinstance(sde, OUVESDE):
164
+ def score_fn(x, t):
165
+ if continuous:
166
+ labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
167
+ else:
168
+ # For VE-trained models, t=0 corresponds to the highest noise level
169
+ labels = sde.T - t
170
+ labels *= sde.N - 1
171
+ labels = torch.round(labels).long()
172
+
173
+ score = model_fn(x, labels)
174
+ return score
175
+
176
+ else:
177
+ raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
178
+
179
+ return score_fn
180
+
181
+
182
+ def to_flattened_numpy(x):
183
+ """Flatten a torch tensor `x` and convert it to numpy."""
184
+ return x.detach().cpu().numpy().reshape((-1,))
185
+
186
+
187
+ def from_flattened_numpy(x, shape):
188
+ """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
189
+ return torch.from_numpy(x.reshape(shape))
sgmse/backbones/shared.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from sgmse.util.registry import Registry
8
+
9
+
10
+ BackboneRegistry = Registry("Backbone")
11
+
12
+
13
+ class GaussianFourierProjection(nn.Module):
14
+ """Gaussian random features for encoding time steps."""
15
+
16
+ def __init__(self, embed_dim, scale=16, complex_valued=False):
17
+ super().__init__()
18
+ self.complex_valued = complex_valued
19
+ if not complex_valued:
20
+ # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
21
+ # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
22
+ # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
23
+ # and this halving is not necessary.
24
+ embed_dim = embed_dim // 2
25
+ # Randomly sample weights during initialization. These weights are fixed
26
+ # during optimization and are not trainable.
27
+ self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)
28
+
29
+ def forward(self, t):
30
+ t_proj = t[:, None] * self.W[None, :] * 2*np.pi
31
+ if self.complex_valued:
32
+ return torch.exp(1j * t_proj)
33
+ else:
34
+ return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
35
+
36
+
37
+ class DiffusionStepEmbedding(nn.Module):
38
+ """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017."""
39
+
40
+ def __init__(self, embed_dim, complex_valued=False):
41
+ super().__init__()
42
+ self.complex_valued = complex_valued
43
+ if not complex_valued:
44
+ # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
45
+ # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
46
+ # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
47
+ # and this halving is not necessary.
48
+ embed_dim = embed_dim // 2
49
+ self.embed_dim = embed_dim
50
+
51
+ def forward(self, t):
52
+ fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / (self.embed_dim-1))
53
+ inner = t[:, None] * fac[None, :]
54
+ if self.complex_valued:
55
+ return torch.exp(1j * inner)
56
+ else:
57
+ return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
58
+
59
+
60
+ class ComplexLinear(nn.Module):
61
+ """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""
62
+ def __init__(self, input_dim, output_dim, complex_valued):
63
+ super().__init__()
64
+ self.complex_valued = complex_valued
65
+ if self.complex_valued:
66
+ self.re = nn.Linear(input_dim, output_dim)
67
+ self.im = nn.Linear(input_dim, output_dim)
68
+ else:
69
+ self.lin = nn.Linear(input_dim, output_dim)
70
+
71
+ def forward(self, x):
72
+ if self.complex_valued:
73
+ return (self.re(x.real) - self.im(x.imag)) + 1j*(self.re(x.imag) + self.im(x.real))
74
+ else:
75
+ return self.lin(x)
76
+
77
+
78
+ class FeatureMapDense(nn.Module):
79
+ """A fully connected layer that reshapes outputs to feature maps."""
80
+
81
+ def __init__(self, input_dim, output_dim, complex_valued=False):
82
+ super().__init__()
83
+ self.complex_valued = complex_valued
84
+ self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)
85
+
86
+ def forward(self, x):
87
+ return self.dense(x)[..., None, None]
88
+
89
+
90
+ def torch_complex_from_reim(re, im):
91
+ return torch.view_as_complex(torch.stack([re, im], dim=-1))
92
+
93
+
94
+ class ArgsComplexMultiplicationWrapper(nn.Module):
95
+ """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().
96
+
97
+ Make a complex-valued module `F` from a real-valued module `f` by applying
98
+ complex multiplication rules:
99
+
100
+ F(a + i b) = f1(a) - f1(b) + i (f2(b) + f2(a))
101
+
102
+ where `f1`, `f2` are instances of `f` that do *not* share weights.
103
+
104
+ Args:
105
+ module_cls (callable): A class or function that returns a Torch module/functional.
106
+ Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
107
+ to construct the real and imaginary component modules.
108
+ """
109
+
110
+ def __init__(self, module_cls, *args, **kwargs):
111
+ super().__init__()
112
+ self.re_module = module_cls(*args, **kwargs)
113
+ self.im_module = module_cls(*args, **kwargs)
114
+
115
+ def forward(self, x, *args, **kwargs):
116
+ return torch_complex_from_reim(
117
+ self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs),
118
+ self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs),
119
+ )
120
+
121
+
122
+ ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
123
+ ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
sgmse/data_module.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from os.path import join
3
+ import torch
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import Dataset
6
+ from torch.utils.data import DataLoader
7
+ from glob import glob
8
+ from torchaudio import load
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def get_window(window_type, window_length):
14
+ if window_type == 'sqrthann':
15
+ return torch.sqrt(torch.hann_window(window_length, periodic=True))
16
+ elif window_type == 'hann':
17
+ return torch.hann_window(window_length, periodic=True)
18
+ else:
19
+ raise NotImplementedError(f"Window type {window_type} not implemented!")
20
+
21
+
22
+ class Specs(Dataset):
23
+ def __init__(self, data_dir, subset, dummy, shuffle_spec, num_frames,
24
+ format='default', normalize="noisy", spec_transform=None,
25
+ stft_kwargs=None, **ignored_kwargs):
26
+
27
+ # Read file paths according to file naming format.
28
+ if format == "default":
29
+ self.clean_files = []
30
+ self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
31
+ self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
32
+ self.noisy_files = []
33
+ self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
34
+ self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
35
+ elif format == "reverb":
36
+ self.clean_files = []
37
+ self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
38
+ self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
39
+ self.noisy_files = []
40
+ self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
41
+ self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
42
+ else:
43
+ # Feel free to add your own directory format
44
+ raise NotImplementedError(f"Directory format {format} unknown!")
45
+
46
+ self.dummy = dummy
47
+ self.num_frames = num_frames
48
+ self.shuffle_spec = shuffle_spec
49
+ self.normalize = normalize
50
+ self.spec_transform = spec_transform
51
+
52
+ assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
53
+ self.stft_kwargs = stft_kwargs
54
+ self.hop_length = self.stft_kwargs["hop_length"]
55
+ assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"
56
+
57
+ def __getitem__(self, i):
58
+ x, _ = load(self.clean_files[i])
59
+ y, _ = load(self.noisy_files[i])
60
+
61
+ # formula applies for center=True
62
+ target_len = (self.num_frames - 1) * self.hop_length
63
+ current_len = x.size(-1)
64
+ pad = max(target_len - current_len, 0)
65
+ if pad == 0:
66
+ # extract random part of the audio file
67
+ if self.shuffle_spec:
68
+ start = int(np.random.uniform(0, current_len-target_len))
69
+ else:
70
+ start = int((current_len-target_len)/2)
71
+ x = x[..., start:start+target_len]
72
+ y = y[..., start:start+target_len]
73
+ else:
74
+ # pad audio if the length T is smaller than num_frames
75
+ x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
76
+ y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')
77
+
78
+ # normalize w.r.t to the noisy or the clean signal or not at all
79
+ # to ensure same clean signal power in x and y.
80
+ if self.normalize == "noisy":
81
+ normfac = y.abs().max()
82
+ elif self.normalize == "clean":
83
+ normfac = x.abs().max()
84
+ elif self.normalize == "not":
85
+ normfac = 1.0
86
+ x = x / normfac
87
+ y = y / normfac
88
+
89
+ X = torch.stft(x, **self.stft_kwargs)
90
+ Y = torch.stft(y, **self.stft_kwargs)
91
+
92
+ X, Y = self.spec_transform(X), self.spec_transform(Y)
93
+ return X, Y
94
+
95
+ def __len__(self):
96
+ if self.dummy:
97
+ # for debugging shrink the data set size
98
+ return int(len(self.clean_files)/200)
99
+ else:
100
+ return len(self.clean_files)
101
+
102
+
103
+ class SpecsDataModule(pl.LightningDataModule):
104
+ @staticmethod
105
+ def add_argparse_args(parser):
106
+ parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.")
107
+ parser.add_argument("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.")
108
+ parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
109
+ parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.") # to assure 256 freq bins
110
+ parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
111
+ parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
112
+ parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
113
+ parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to use for DataLoaders. 4 by default.")
114
+ parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
115
+ parser.add_argument("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. 0.15 by default.")
116
+ parser.add_argument("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). 0.5 by default.")
117
+ parser.add_argument("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
118
+ parser.add_argument("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectogram transformation for input representation.")
119
+ return parser
120
+
121
+ def __init__(
122
+ self, base_dir, format='default', batch_size=8,
123
+ n_fft=510, hop_length=128, num_frames=256, window='hann',
124
+ num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
125
+ gpu=True, normalize='noisy', transform_type="exponent", **kwargs
126
+ ):
127
+ super().__init__()
128
+ self.base_dir = base_dir
129
+ self.format = format
130
+ self.batch_size = batch_size
131
+ self.n_fft = n_fft
132
+ self.hop_length = hop_length
133
+ self.num_frames = num_frames
134
+ self.window = get_window(window, self.n_fft)
135
+ self.windows = {}
136
+ self.num_workers = num_workers
137
+ self.dummy = dummy
138
+ self.spec_factor = spec_factor
139
+ self.spec_abs_exponent = spec_abs_exponent
140
+ self.gpu = gpu
141
+ self.normalize = normalize
142
+ self.transform_type = transform_type
143
+ self.kwargs = kwargs
144
+
145
+ def setup(self, stage=None):
146
+ specs_kwargs = dict(
147
+ stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
148
+ spec_transform=self.spec_fwd, **self.kwargs
149
+ )
150
+ if stage == 'fit' or stage is None:
151
+ self.train_set = Specs(data_dir=self.base_dir, subset='train',
152
+ dummy=self.dummy, shuffle_spec=True, format=self.format,
153
+ normalize=self.normalize, **specs_kwargs)
154
+ self.valid_set = Specs(data_dir=self.base_dir, subset='valid',
155
+ dummy=self.dummy, shuffle_spec=False, format=self.format,
156
+ normalize=self.normalize, **specs_kwargs)
157
+ if stage == 'test' or stage is None:
158
+ self.test_set = Specs(data_dir=self.base_dir, subset='test',
159
+ dummy=self.dummy, shuffle_spec=False, format=self.format,
160
+ normalize=self.normalize, **specs_kwargs)
161
+
162
+ def spec_fwd(self, spec):
163
+ if self.transform_type == "exponent":
164
+ if self.spec_abs_exponent != 1:
165
+ # only do this calculation if spec_exponent != 1, otherwise it's quite a bit of wasted computation
166
+ # and introduced numerical error
167
+ e = self.spec_abs_exponent
168
+ spec = spec.abs()**e * torch.exp(1j * spec.angle())
169
+ spec = spec * self.spec_factor
170
+ elif self.transform_type == "log":
171
+ spec = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
172
+ spec = spec * self.spec_factor
173
+ elif self.transform_type == "none":
174
+ spec = spec
175
+ return spec
176
+
177
+ def spec_back(self, spec):
178
+ if self.transform_type == "exponent":
179
+ spec = spec / self.spec_factor
180
+ if self.spec_abs_exponent != 1:
181
+ e = self.spec_abs_exponent
182
+ spec = spec.abs()**(1/e) * torch.exp(1j * spec.angle())
183
+ elif self.transform_type == "log":
184
+ spec = spec / self.spec_factor
185
+ spec = (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
186
+ elif self.transform_type == "none":
187
+ spec = spec
188
+ return spec
189
+
190
+ @property
191
+ def stft_kwargs(self):
192
+ return {**self.istft_kwargs, "return_complex": True}
193
+
194
+ @property
195
+ def istft_kwargs(self):
196
+ return dict(
197
+ n_fft=self.n_fft, hop_length=self.hop_length,
198
+ window=self.window, center=True
199
+ )
200
+
201
+ def _get_window(self, x):
202
+ """
203
+ Retrieve an appropriate window for the given tensor x, matching the device.
204
+ Caches the retrieved windows so that only one window tensor will be allocated per device.
205
+ """
206
+ window = self.windows.get(x.device, None)
207
+ if window is None:
208
+ window = self.window.to(x.device)
209
+ self.windows[x.device] = window
210
+ return window
211
+
212
+ def stft(self, sig):
213
+ window = self._get_window(sig)
214
+ return torch.stft(sig, **{**self.stft_kwargs, "window": window})
215
+
216
+ def istft(self, spec, length=None):
217
+ window = self._get_window(spec)
218
+ return torch.istft(spec, **{**self.istft_kwargs, "window": window, "length": length})
219
+
220
+ def train_dataloader(self):
221
+ return DataLoader(
222
+ self.train_set, batch_size=self.batch_size,
223
+ num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
224
+ )
225
+
226
+ def val_dataloader(self):
227
+ return DataLoader(
228
+ self.valid_set, batch_size=self.batch_size,
229
+ num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
230
+ )
231
+
232
+ def test_dataloader(self):
233
+ return DataLoader(
234
+ self.test_set, batch_size=self.batch_size,
235
+ num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
236
+ )
sgmse/model.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from math import ceil
3
+ import warnings
4
+
5
+ import torch
6
+ import pytorch_lightning as pl
7
+ from torch_ema import ExponentialMovingAverage
8
+
9
+ from sgmse import sampling
10
+ from sgmse.sdes import SDERegistry
11
+ from sgmse.backbones import BackboneRegistry
12
+ from sgmse.util.inference import evaluate_model
13
+ from sgmse.util.other import pad_spec
14
+
15
+
16
+ class ScoreModel(pl.LightningModule):
17
+ @staticmethod
18
+ def add_argparse_args(parser):
19
+ parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
20
+ parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
21
+ parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
22
+ parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
23
+ parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
24
+ return parser
25
+
26
+ def __init__(
27
+ self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
28
+ num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
29
+ ):
30
+ """
31
+ Create a new ScoreModel.
32
+
33
+ Args:
34
+ backbone: Backbone DNN that serves as a score-based model.
35
+ sde: The SDE that defines the diffusion process.
36
+ lr: The learning rate of the optimizer. (1e-4 by default).
37
+ ema_decay: The decay constant of the parameter EMA (0.999 by default).
38
+ t_eps: The minimum time to practically run for to avoid issues very close to zero (1e-5 by default).
39
+ loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'
40
+ """
41
+ super().__init__()
42
+ # Initialize Backbone DNN
43
+ self.backbone = backbone
44
+ dnn_cls = BackboneRegistry.get_by_name(backbone)
45
+ self.dnn = dnn_cls(**kwargs)
46
+ # Initialize SDE
47
+ sde_cls = SDERegistry.get_by_name(sde)
48
+ self.sde = sde_cls(**kwargs)
49
+ # Store hyperparams and save them
50
+ self.lr = lr
51
+ self.ema_decay = ema_decay
52
+ self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
53
+ self._error_loading_ema = False
54
+ self.t_eps = t_eps
55
+ self.loss_type = loss_type
56
+ self.num_eval_files = num_eval_files
57
+
58
+ self.save_hyperparameters(ignore=['no_wandb'])
59
+ self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
60
+
61
+ def configure_optimizers(self):
62
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
63
+ return optimizer
64
+
65
+ def optimizer_step(self, *args, **kwargs):
66
+ # Method overridden so that the EMA params are updated after each optimizer step
67
+ super().optimizer_step(*args, **kwargs)
68
+ self.ema.update(self.parameters())
69
+
70
+ # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
71
+ def on_load_checkpoint(self, checkpoint):
72
+ ema = checkpoint.get('ema', None)
73
+ if ema is not None:
74
+ self.ema.load_state_dict(checkpoint['ema'])
75
+ else:
76
+ self._error_loading_ema = True
77
+ warnings.warn("EMA state_dict not found in checkpoint!")
78
+
79
+ def on_save_checkpoint(self, checkpoint):
80
+ checkpoint['ema'] = self.ema.state_dict()
81
+
82
+ def train(self, mode, no_ema=False):
83
+ res = super().train(mode) # call the standard `train` method with the given mode
84
+ if not self._error_loading_ema:
85
+ if mode == False and not no_ema:
86
+ # eval
87
+ self.ema.store(self.parameters()) # store current params in EMA
88
+ self.ema.copy_to(self.parameters()) # copy EMA parameters over current params for evaluation
89
+ else:
90
+ # train
91
+ if self.ema.collected_params is not None:
92
+ self.ema.restore(self.parameters()) # restore the EMA weights (if stored)
93
+ return res
94
+
95
+ def eval(self, no_ema=False):
96
+ return self.train(False, no_ema=no_ema)
97
+
98
+ def _loss(self, err):
99
+ if self.loss_type == 'mse':
100
+ losses = torch.square(err.abs())
101
+ elif self.loss_type == 'mae':
102
+ losses = err.abs()
103
+ # taken from reduce_op function: sum over channels and position and mean over batch dim
104
+ # presumably only important for absolute loss number, not for gradients
105
+ loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
106
+ return loss
107
+
108
+ def _step(self, batch, batch_idx):
109
+ x, y = batch
110
+ t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
111
+ mean, std = self.sde.marginal_prob(x, t, y)
112
+ z = torch.randn_like(x) # i.i.d. normal distributed with var=0.5
113
+ sigmas = std[:, None, None, None]
114
+ perturbed_data = mean + sigmas * z
115
+ score = self(perturbed_data, t, y)
116
+ err = score * sigmas + z
117
+ loss = self._loss(err)
118
+ return loss
119
+
120
+ def training_step(self, batch, batch_idx):
121
+ loss = self._step(batch, batch_idx)
122
+ self.log('train_loss', loss, on_step=True, on_epoch=True)
123
+ return loss
124
+
125
+ def validation_step(self, batch, batch_idx):
126
+ loss = self._step(batch, batch_idx)
127
+ self.log('valid_loss', loss, on_step=False, on_epoch=True)
128
+
129
+ # Evaluate speech enhancement performance
130
+ if batch_idx == 0 and self.num_eval_files != 0:
131
+ pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
132
+ self.log('pesq', pesq, on_step=False, on_epoch=True)
133
+ self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
134
+ self.log('estoi', estoi, on_step=False, on_epoch=True)
135
+
136
+ return loss
137
+
138
+ def forward(self, x, t, y):
139
+ # Concatenate y as an extra channel
140
+ dnn_input = torch.cat([x, y], dim=1)
141
+
142
+ # the minus is most likely unimportant here - taken from Song's repo
143
+ score = -self.dnn(dnn_input, t)
144
+ return score
145
+
146
+ def to(self, *args, **kwargs):
147
+ """Override PyTorch .to() to also transfer the EMA of the model weights"""
148
+ self.ema.to(*args, **kwargs)
149
+ return super().to(*args, **kwargs)
150
+
151
+ def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
152
+ N = self.sde.N if N is None else N
153
+ sde = self.sde.copy()
154
+ sde.N = N
155
+
156
+ kwargs = {"eps": self.t_eps, **kwargs}
157
+ if minibatch is None:
158
+ return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
159
+ else:
160
+ M = y.shape[0]
161
+ def batched_sampling_fn():
162
+ samples, ns = [], []
163
+ for i in range(int(ceil(M / minibatch))):
164
+ y_mini = y[i*minibatch:(i+1)*minibatch]
165
+ sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
166
+ sample, n = sampler()
167
+ samples.append(sample)
168
+ ns.append(n)
169
+ samples = torch.cat(samples, dim=0)
170
+ return samples, ns
171
+ return batched_sampling_fn
172
+
173
+ def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
174
+ N = self.sde.N if N is None else N
175
+ sde = self.sde.copy()
176
+ sde.N = N
177
+
178
+ kwargs = {"eps": self.t_eps, **kwargs}
179
+ if minibatch is None:
180
+ return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
181
+ else:
182
+ M = y.shape[0]
183
+ def batched_sampling_fn():
184
+ samples, ns = [], []
185
+ for i in range(int(ceil(M / minibatch))):
186
+ y_mini = y[i*minibatch:(i+1)*minibatch]
187
+ sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
188
+ sample, n = sampler()
189
+ samples.append(sample)
190
+ ns.append(n)
191
+ samples = torch.cat(samples, dim=0)
192
+ return sample, ns
193
+ return batched_sampling_fn
194
+
195
+ def train_dataloader(self):
196
+ return self.data_module.train_dataloader()
197
+
198
+ def val_dataloader(self):
199
+ return self.data_module.val_dataloader()
200
+
201
+ def test_dataloader(self):
202
+ return self.data_module.test_dataloader()
203
+
204
+ def setup(self, stage=None):
205
+ return self.data_module.setup(stage=stage)
206
+
207
+ def to_audio(self, spec, length=None):
208
+ return self._istft(self._backward_transform(spec), length)
209
+
210
+ def _forward_transform(self, spec):
211
+ return self.data_module.spec_fwd(spec)
212
+
213
+ def _backward_transform(self, spec):
214
+ return self.data_module.spec_back(spec)
215
+
216
+ def _stft(self, sig):
217
+ return self.data_module.stft(sig)
218
+
219
+ def _istft(self, spec, length=None):
220
+ return self.data_module.istft(spec, length)
221
+
222
+ def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
223
+ corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
224
+ **kwargs
225
+ ):
226
+ """
227
+ One-call speech enhancement of noisy speech `y`, for convenience.
228
+ """
229
+ sr=16000
230
+ start = time.time()
231
+ T_orig = y.size(1)
232
+ norm_factor = y.abs().max().item()
233
+ y = y / norm_factor
234
+ Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
235
+ Y = pad_spec(Y)
236
+ if sampler_type == "pc":
237
+ sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
238
+ corrector_steps=corrector_steps, snr=snr, intermediate=False,
239
+ **kwargs)
240
+ elif sampler_type == "ode":
241
+ sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
242
+ else:
243
+ print("{} is not a valid sampler type!".format(sampler_type))
244
+ sample, nfe = sampler()
245
+ x_hat = self.to_audio(sample.squeeze(), T_orig)
246
+ x_hat = x_hat * norm_factor
247
+ x_hat = x_hat.squeeze().cpu().numpy()
248
+ end = time.time()
249
+ if timeit:
250
+ rtf = (end-start)/(len(x_hat)/sr)
251
+ return x_hat, nfe, rtf
252
+ else:
253
+ return x_hat
sgmse/sampling/__init__.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
2
+ """Various sampling methods."""
3
+ from scipy import integrate
4
+ import torch
5
+
6
+ from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
7
+ from .correctors import Corrector, CorrectorRegistry
8
+
9
+
10
+ __all__ = [
11
+ 'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
12
+ 'get_sampler'
13
+ ]
14
+
15
+
16
+ def to_flattened_numpy(x):
17
+ """Flatten a torch tensor `x` and convert it to numpy."""
18
+ return x.detach().cpu().numpy().reshape((-1,))
19
+
20
+
21
+ def from_flattened_numpy(x, shape):
22
+ """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
23
+ return torch.from_numpy(x.reshape(shape))
24
+
25
+
26
+ def get_pc_sampler(
27
+ predictor_name, corrector_name, sde, score_fn, y,
28
+ denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
29
+ intermediate=False, **kwargs
30
+ ):
31
+ """Create a Predictor-Corrector (PC) sampler.
32
+
33
+ Args:
34
+ predictor_name: The name of a registered `sampling.Predictor`.
35
+ corrector_name: The name of a registered `sampling.Corrector`.
36
+ sde: An `sdes.SDE` object representing the forward SDE.
37
+ score_fn: A function (typically learned model) that predicts the score.
38
+ y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
39
+ denoise: If `True`, add one-step denoising to the final samples.
40
+ eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
41
+ snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
42
+ N: The number of reverse sampling steps. If `None`, uses the SDE's `N` property by default.
43
+
44
+ Returns:
45
+ A sampling function that returns samples and the number of function evaluations during sampling.
46
+ """
47
+ predictor_cls = PredictorRegistry.get_by_name(predictor_name)
48
+ corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
49
+ predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
50
+ corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)
51
+
52
+ def pc_sampler():
53
+ """The PC sampler function."""
54
+ with torch.no_grad():
55
+ xt = sde.prior_sampling(y.shape, y).to(y.device)
56
+ timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
57
+ for i in range(sde.N):
58
+ t = timesteps[i]
59
+ if i != len(timesteps) - 1:
60
+ stepsize = t - timesteps[i+1]
61
+ else:
62
+ stepsize = timesteps[-1] # from eps to 0
63
+ vec_t = torch.ones(y.shape[0], device=y.device) * t
64
+ xt, xt_mean = corrector.update_fn(xt, vec_t, y)
65
+ xt, xt_mean = predictor.update_fn(xt, vec_t, y, stepsize)
66
+ x_result = xt_mean if denoise else xt
67
+ ns = sde.N * (corrector.n_steps + 1)
68
+ return x_result, ns
69
+
70
+ return pc_sampler
71
+
72
+
73
+ def get_ode_sampler(
74
+ sde, score_fn, y, inverse_scaler=None,
75
+ denoise=True, rtol=1e-5, atol=1e-5,
76
+ method='RK45', eps=3e-2, device='cuda', **kwargs
77
+ ):
78
+ """Probability flow ODE sampler with the black-box ODE solver.
79
+
80
+ Args:
81
+ sde: An `sdes.SDE` object representing the forward SDE.
82
+ score_fn: A function (typically learned model) that predicts the score.
83
+ y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
84
+ inverse_scaler: The inverse data normalizer.
85
+ denoise: If `True`, add one-step denoising to final samples.
86
+ rtol: A `float` number. The relative tolerance level of the ODE solver.
87
+ atol: A `float` number. The absolute tolerance level of the ODE solver.
88
+ method: A `str`. The algorithm used for the black-box ODE solver.
89
+ See the documentation of `scipy.integrate.solve_ivp`.
90
+ eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
91
+ device: PyTorch device.
92
+
93
+ Returns:
94
+ A sampling function that returns samples and the number of function evaluations during sampling.
95
+ """
96
+ predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
97
+ rsde = sde.reverse(score_fn, probability_flow=True)
98
+
99
+ def denoise_update_fn(x):
100
+ vec_eps = torch.ones(x.shape[0], device=x.device) * eps
101
+ _, x = predictor.update_fn(x, vec_eps, y)
102
+ return x
103
+
104
+ def drift_fn(x, t, y):
105
+ """Get the drift function of the reverse-time SDE."""
106
+ return rsde.sde(x, t, y)[0]
107
+
108
+ def ode_sampler(z=None, **kwargs):
109
+ """The probability flow ODE sampler with black-box ODE solver.
110
+
111
+ Args:
112
+ model: A score model.
113
+ z: If present, generate samples from latent code `z`.
114
+ Returns:
115
+ samples, number of function evaluations.
116
+ """
117
+ with torch.no_grad():
118
+ # If not represent, sample the latent code from the prior distibution of the SDE.
119
+ x = sde.prior_sampling(y.shape, y).to(device)
120
+
121
+ def ode_func(t, x):
122
+ x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
123
+ vec_t = torch.ones(y.shape[0], device=x.device) * t
124
+ drift = drift_fn(x, vec_t, y)
125
+ return to_flattened_numpy(drift)
126
+
127
+ # Black-box ODE solver for the probability flow ODE
128
+ solution = integrate.solve_ivp(
129
+ ode_func, (sde.T, eps), to_flattened_numpy(x),
130
+ rtol=rtol, atol=atol, method=method, **kwargs
131
+ )
132
+ nfe = solution.nfev
133
+ x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)
134
+
135
+ # Denoising is equivalent to running one predictor step without adding noise
136
+ if denoise:
137
+ x = denoise_update_fn(x)
138
+
139
+ if inverse_scaler is not None:
140
+ x = inverse_scaler(x)
141
+ return x, nfe
142
+
143
+ return ode_sampler
sgmse/sampling/correctors.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import torch
3
+
4
+ from sgmse import sdes
5
+ from sgmse.util.registry import Registry
6
+
7
+
8
+ CorrectorRegistry = Registry("Corrector")
9
+
10
+
11
+ class Corrector(abc.ABC):
12
+ """The abstract class for a corrector algorithm."""
13
+
14
+ def __init__(self, sde, score_fn, snr, n_steps):
15
+ super().__init__()
16
+ self.rsde = sde.reverse(score_fn)
17
+ self.score_fn = score_fn
18
+ self.snr = snr
19
+ self.n_steps = n_steps
20
+
21
+ @abc.abstractmethod
22
+ def update_fn(self, x, t, *args):
23
+ """One update of the corrector.
24
+
25
+ Args:
26
+ x: A PyTorch tensor representing the current state
27
+ t: A PyTorch tensor representing the current time step.
28
+ *args: Possibly additional arguments, in particular `y` for OU processes
29
+
30
+ Returns:
31
+ x: A PyTorch tensor of the next state.
32
+ x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
33
+ """
34
+ pass
35
+
36
+
37
+ @CorrectorRegistry.register(name='langevin')
38
+ class LangevinCorrector(Corrector):
39
+ def __init__(self, sde, score_fn, snr, n_steps):
40
+ super().__init__(sde, score_fn, snr, n_steps)
41
+ self.score_fn = score_fn
42
+ self.n_steps = n_steps
43
+ self.snr = snr
44
+
45
+ def update_fn(self, x, t, *args):
46
+ target_snr = self.snr
47
+ for _ in range(self.n_steps):
48
+ grad = self.score_fn(x, t, *args)
49
+ noise = torch.randn_like(x)
50
+ grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
51
+ noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
52
+ step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0)
53
+ x_mean = x + step_size[:, None, None, None] * grad
54
+ x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
55
+
56
+ return x, x_mean
57
+
58
+
59
+ @CorrectorRegistry.register(name='ald')
60
+ class AnnealedLangevinDynamics(Corrector):
61
+ """The original annealed Langevin dynamics predictor in NCSN/NCSNv2."""
62
+ def __init__(self, sde, score_fn, snr, n_steps):
63
+ super().__init__(sde, score_fn, snr, n_steps)
64
+ if not isinstance(sde, (sdes.OUVESDE,)):
65
+ raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
66
+ self.sde = sde
67
+ self.score_fn = score_fn
68
+ self.snr = snr
69
+ self.n_steps = n_steps
70
+
71
+ def update_fn(self, x, t, *args):
72
+ n_steps = self.n_steps
73
+ target_snr = self.snr
74
+ std = self.sde.marginal_prob(x, t, *args)[1]
75
+
76
+ for _ in range(n_steps):
77
+ grad = self.score_fn(x, t, *args)
78
+ noise = torch.randn_like(x)
79
+ step_size = (target_snr * std) ** 2 * 2
80
+ x_mean = x + step_size[:, None, None, None] * grad
81
+ x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
82
+
83
+ return x, x_mean
84
+
85
+
86
+ @CorrectorRegistry.register(name='none')
87
+ class NoneCorrector(Corrector):
88
+ """An empty corrector that does nothing."""
89
+
90
+ def __init__(self, *args, **kwargs):
91
+ self.snr = 0
92
+ self.n_steps = 0
93
+ pass
94
+
95
+ def update_fn(self, x, t, *args):
96
+ return x, x
sgmse/sampling/predictors.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from sgmse.util.registry import Registry
7
+
8
+
9
+ PredictorRegistry = Registry("Predictor")
10
+
11
+
12
+ class Predictor(abc.ABC):
13
+ """The abstract class for a predictor algorithm."""
14
+
15
+ def __init__(self, sde, score_fn, probability_flow=False):
16
+ super().__init__()
17
+ self.sde = sde
18
+ self.rsde = sde.reverse(score_fn)
19
+ self.score_fn = score_fn
20
+ self.probability_flow = probability_flow
21
+
22
+ @abc.abstractmethod
23
+ def update_fn(self, x, t, *args):
24
+ """One update of the predictor.
25
+
26
+ Args:
27
+ x: A PyTorch tensor representing the current state
28
+ t: A Pytorch tensor representing the current time step.
29
+ *args: Possibly additional arguments, in particular `y` for OU processes
30
+
31
+ Returns:
32
+ x: A PyTorch tensor of the next state.
33
+ x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
34
+ """
35
+ pass
36
+
37
+ def debug_update_fn(self, x, t, *args):
38
+ raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")
39
+
40
+
41
+ @PredictorRegistry.register('euler_maruyama')
42
+ class EulerMaruyamaPredictor(Predictor):
43
+ def __init__(self, sde, score_fn, probability_flow=False):
44
+ super().__init__(sde, score_fn, probability_flow=probability_flow)
45
+
46
+ def update_fn(self, x, t, *args):
47
+ dt = -1. / self.rsde.N
48
+ z = torch.randn_like(x)
49
+ f, g = self.rsde.sde(x, t, *args)
50
+ x_mean = x + f * dt
51
+ x = x_mean + g[:, None, None, None] * np.sqrt(-dt) * z
52
+ return x, x_mean
53
+
54
+
55
+ @PredictorRegistry.register('reverse_diffusion')
56
+ class ReverseDiffusionPredictor(Predictor):
57
+ def __init__(self, sde, score_fn, probability_flow=False):
58
+ super().__init__(sde, score_fn, probability_flow=probability_flow)
59
+
60
+ def update_fn(self, x, t, y, stepsize):
61
+ f, g = self.rsde.discretize(x, t, y, stepsize)
62
+ z = torch.randn_like(x)
63
+ x_mean = x - f
64
+ x = x_mean + g[:, None, None, None] * z
65
+ return x, x_mean
66
+
67
+
68
+ @PredictorRegistry.register('none')
69
+ class NonePredictor(Predictor):
70
+ """An empty predictor that does nothing."""
71
+
72
+ def __init__(self, *args, **kwargs):
73
+ pass
74
+
75
+ def update_fn(self, x, t, *args):
76
+ return x, x
sgmse/sdes.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
3
+
4
+ Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
5
+ """
6
+ import abc
7
+ import warnings
8
+
9
+ import numpy as np
10
+ from sgmse.util.tensors import batch_broadcast
11
+ import torch
12
+
13
+ from sgmse.util.registry import Registry
14
+
15
+
16
+ SDERegistry = Registry("SDE")
17
+
18
+
19
+ class SDE(abc.ABC):
20
+ """SDE abstract class. Functions are designed for a mini-batch of inputs."""
21
+
22
+ def __init__(self, N):
23
+ """Construct an SDE.
24
+
25
+ Args:
26
+ N: number of discretization time steps.
27
+ """
28
+ super().__init__()
29
+ self.N = N
30
+
31
+ @property
32
+ @abc.abstractmethod
33
+ def T(self):
34
+ """End time of the SDE."""
35
+ pass
36
+
37
+ @abc.abstractmethod
38
+ def sde(self, x, t, *args):
39
+ pass
40
+
41
+ @abc.abstractmethod
42
+ def marginal_prob(self, x, t, *args):
43
+ """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
44
+ pass
45
+
46
+ @abc.abstractmethod
47
+ def prior_sampling(self, shape, *args):
48
+ """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
49
+ pass
50
+
51
+ @abc.abstractmethod
52
+ def prior_logp(self, z):
53
+ """Compute log-density of the prior distribution.
54
+
55
+ Useful for computing the log-likelihood via probability flow ODE.
56
+
57
+ Args:
58
+ z: latent code
59
+ Returns:
60
+ log probability density
61
+ """
62
+ pass
63
+
64
+ @staticmethod
65
+ @abc.abstractmethod
66
+ def add_argparse_args(parent_parser):
67
+ """
68
+ Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
69
+ """
70
+ pass
71
+
72
+ def discretize(self, x, t, y, stepsize):
73
+ """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
74
+
75
+ Useful for reverse diffusion sampling and probabiliy flow sampling.
76
+ Defaults to Euler-Maruyama discretization.
77
+
78
+ Args:
79
+ x: a torch tensor
80
+ t: a torch float representing the time step (from 0 to `self.T`)
81
+
82
+ Returns:
83
+ f, G
84
+ """
85
+ dt = stepsize
86
+ drift, diffusion = self.sde(x, t, y)
87
+ f = drift * dt
88
+ G = diffusion * torch.sqrt(dt)
89
+ return f, G
90
+
91
+ def reverse(oself, score_model, probability_flow=False):
92
+ """Create the reverse-time SDE/ODE.
93
+
94
+ Args:
95
+ score_model: A function that takes x, t and y and returns the score.
96
+ probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
97
+ """
98
+ N = oself.N
99
+ T = oself.T
100
+ sde_fn = oself.sde
101
+ discretize_fn = oself.discretize
102
+
103
+ # Build the class for reverse-time SDE.
104
+ class RSDE(oself.__class__):
105
+ def __init__(self):
106
+ self.N = N
107
+ self.probability_flow = probability_flow
108
+
109
+ @property
110
+ def T(self):
111
+ return T
112
+
113
+ def sde(self, x, t, *args):
114
+ """Create the drift and diffusion functions for the reverse SDE/ODE."""
115
+ rsde_parts = self.rsde_parts(x, t, *args)
116
+ total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
117
+ return total_drift, diffusion
118
+
119
+ def rsde_parts(self, x, t, *args):
120
+ sde_drift, sde_diffusion = sde_fn(x, t, *args)
121
+ score = score_model(x, t, *args)
122
+ score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
123
+ diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
124
+ total_drift = sde_drift + score_drift
125
+ return {
126
+ 'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
127
+ 'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
128
+ }
129
+
130
+ def discretize(self, x, t, y, stepsize):
131
+ """Create discretized iteration rules for the reverse diffusion sampler."""
132
+ f, G = discretize_fn(x, t, y, stepsize)
133
+ rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
134
+ rev_G = torch.zeros_like(G) if self.probability_flow else G
135
+ return rev_f, rev_G
136
+
137
+ return RSDE()
138
+
139
+ @abc.abstractmethod
140
+ def copy(self):
141
+ pass
142
+
143
+
144
+ @SDERegistry.register("ouve")
145
+ class OUVESDE(SDE):
146
+ @staticmethod
147
+ def add_argparse_args(parser):
148
+ parser.add_argument("--sde-n", type=int, default=1000, help="The number of timesteps in the SDE discretization. 30 by default")
149
+ parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
150
+ parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
151
+ parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
152
+ return parser
153
+
154
+ def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
155
+ """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.
156
+
157
+ Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
158
+ to the methods which require it (e.g., `sde` or `marginal_prob`).
159
+
160
+ dx = -theta (y-x) dt + sigma(t) dw
161
+
162
+ with
163
+
164
+ sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))
165
+
166
+ Args:
167
+ theta: stiffness parameter.
168
+ sigma_min: smallest sigma.
169
+ sigma_max: largest sigma.
170
+ N: number of discretization steps
171
+ """
172
+ super().__init__(N)
173
+ self.theta = theta
174
+ self.sigma_min = sigma_min
175
+ self.sigma_max = sigma_max
176
+ self.logsig = np.log(self.sigma_max / self.sigma_min)
177
+ self.N = N
178
+
179
+ def copy(self):
180
+ return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)
181
+
182
+ @property
183
+ def T(self):
184
+ return 1
185
+
186
+ def sde(self, x, t, y):
187
+ drift = self.theta * (y - x)
188
+ # the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel
189
+ # standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t
190
+ # with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
191
+ # unless this sqrt(2*logsig) factor is included.
192
+ sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
193
+ diffusion = sigma * np.sqrt(2 * self.logsig)
194
+ return drift, diffusion
195
+
196
+ def _mean(self, x0, t, y):
197
+ theta = self.theta
198
+ exp_interp = torch.exp(-theta * t)[:, None, None, None]
199
+ return exp_interp * x0 + (1 - exp_interp) * y
200
+
201
+ def alpha(self, t):
202
+ return torch.exp(-self.theta * t)
203
+
204
+ def _std(self, t):
205
+ # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
206
+ sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
207
+ # could maybe replace the two torch.exp(... * t) terms here by cached values **t
208
+ return torch.sqrt(
209
+ (
210
+ sigma_min**2
211
+ * torch.exp(-2 * theta * t)
212
+ * (torch.exp(2 * (theta + logsig) * t) - 1)
213
+ * logsig
214
+ )
215
+ /
216
+ (theta + logsig)
217
+ )
218
+
219
+ def marginal_prob(self, x0, t, y):
220
+ return self._mean(x0, t, y), self._std(t)
221
+
222
+ def prior_sampling(self, shape, y):
223
+ if shape != y.shape:
224
+ warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
225
+ std = self._std(torch.ones((y.shape[0],), device=y.device))
226
+ x_T = y + torch.randn_like(y) * std[:, None, None, None]
227
+ return x_T
228
+
229
+ def prior_logp(self, z):
230
+ raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
231
+
232
+
233
+ @SDERegistry.register("ouvp")
234
+ class OUVPSDE(SDE):
235
+ # !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
236
+ @staticmethod
237
+ def add_argparse_args(parser):
238
+ parser.add_argument("--sde-n", type=int, default=1000,
239
+ help="The number of timesteps in the SDE discretization. 1000 by default")
240
+ parser.add_argument("--beta-min", type=float, required=True,
241
+ help="The minimum beta to use.")
242
+ parser.add_argument("--beta-max", type=float, required=True,
243
+ help="The maximum beta to use.")
244
+ parser.add_argument("--stiffness", type=float, default=1,
245
+ help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
246
+ return parser
247
+
248
+ def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
249
+ """
250
+ !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
251
+
252
+ Construct an Ornstein-Uhlenbeck Variance Preserving SDE:
253
+
254
+ dx = -1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw
255
+
256
+ with
257
+
258
+ beta(t) = beta_min + t(beta_max - beta_min)
259
+
260
+ Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
261
+ to the methods which require it (e.g., `sde` or `marginal_prob`).
262
+
263
+ Args:
264
+ beta_min: smallest sigma.
265
+ beta_max: largest sigma.
266
+ stiffness: stiffness factor of the drift. 1 by default.
267
+ N: number of discretization steps
268
+ """
269
+ super().__init__(N)
270
+ self.beta_min = beta_min
271
+ self.beta_max = beta_max
272
+ self.stiffness = stiffness
273
+ self.N = N
274
+
275
+ def copy(self):
276
+ return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)
277
+
278
+ @property
279
+ def T(self):
280
+ return 1
281
+
282
+ def _beta(self, t):
283
+ return self.beta_min + t * (self.beta_max - self.beta_min)
284
+
285
+ def sde(self, x, t, y):
286
+ drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
287
+ diffusion = torch.sqrt(self._beta(t))
288
+ return drift, diffusion
289
+
290
+ def _mean(self, x0, t, y):
291
+ b0, b1, s = self.beta_min, self.beta_max, self.stiffness
292
+ x0y_fac = torch.exp(-0.25 * s * t * (t * (b1-b0) + 2 * b0))[:, None, None, None]
293
+ return y + x0y_fac * (x0 - y)
294
+
295
+ def _std(self, t):
296
+ b0, b1, s = self.beta_min, self.beta_max, self.stiffness
297
+ return (1 - torch.exp(-0.5 * s * t * (t * (b1-b0) + 2 * b0))) / s
298
+
299
+ def marginal_prob(self, x0, t, y):
300
+ return self._mean(x0, t, y), self._std(t)
301
+
302
+ def prior_sampling(self, shape, y):
303
+ if shape != y.shape:
304
+ warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
305
+ std = self._std(torch.ones((y.shape[0],), device=y.device))
306
+ x_T = y + torch.randn_like(y) * std[:, None, None, None]
307
+ return x_T
308
+
309
+ def prior_logp(self, z):
310
+ raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
sgmse/util/inference.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchaudio import load
3
+
4
+ from pesq import pesq
5
+ from pystoi import stoi
6
+
7
+ from .other import si_sdr, pad_spec
8
+
9
+ # Settings
10
+ sr = 16000
11
+ snr = 0.5
12
+ N = 30
13
+ corrector_steps = 1
14
+
15
+
16
+ def evaluate_model(model, num_eval_files):
17
+
18
+ clean_files = model.data_module.valid_set.clean_files
19
+ noisy_files = model.data_module.valid_set.noisy_files
20
+
21
+ # Select test files uniformly accros validation files
22
+ total_num_files = len(clean_files)
23
+ indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
24
+ clean_files = list(clean_files[i] for i in indices)
25
+ noisy_files = list(noisy_files[i] for i in indices)
26
+
27
+ _pesq = 0
28
+ _si_sdr = 0
29
+ _estoi = 0
30
+ # iterate over files
31
+ for (clean_file, noisy_file) in zip(clean_files, noisy_files):
32
+ # Load wavs
33
+ x, _ = load(clean_file)
34
+ y, _ = load(noisy_file)
35
+ T_orig = x.size(1)
36
+
37
+ # Normalize per utterance
38
+ norm_factor = y.abs().max()
39
+ y = y / norm_factor
40
+
41
+ # Prepare DNN input
42
+ Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
43
+ Y = pad_spec(Y)
44
+ y = y * norm_factor
45
+
46
+ # Reverse sampling
47
+ sampler = model.get_pc_sampler(
48
+ 'reverse_diffusion', 'ald', Y.cuda(), N=N,
49
+ corrector_steps=corrector_steps, snr=snr)
50
+ sample, _ = sampler()
51
+
52
+ x_hat = model.to_audio(sample.squeeze(), T_orig)
53
+ x_hat = x_hat * norm_factor
54
+
55
+ x_hat = x_hat.squeeze().cpu().numpy()
56
+ x = x.squeeze().cpu().numpy()
57
+ y = y.squeeze().cpu().numpy()
58
+
59
+ _si_sdr += si_sdr(x, x_hat)
60
+ _pesq += pesq(sr, x, x_hat, 'wb')
61
+ _estoi += stoi(x, x_hat, sr, extended=True)
62
+
63
+ return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
64
+
sgmse/util/other.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import scipy.stats
5
+ from scipy.signal import butter, sosfilt
6
+
7
+ from pesq import pesq
8
+ from pystoi import stoi
9
+
10
+
11
+ def si_sdr_components(s_hat, s, n):
12
+ # s_target
13
+ alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
14
+ s_target = alpha_s * s
15
+
16
+ # e_noise
17
+ alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
18
+ e_noise = alpha_n * n
19
+
20
+ # e_art
21
+ e_art = s_hat - s_target - e_noise
22
+
23
+ return s_target, e_noise, e_art
24
+
25
+ def energy_ratios(s_hat, s, n):
26
+ s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
27
+
28
+ si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
29
+ si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
30
+ si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)
31
+
32
+ return si_sdr, si_sir, si_sar
33
+
34
+ def mean_conf_int(data, confidence=0.95):
35
+ a = 1.0 * np.array(data)
36
+ n = len(a)
37
+ m, se = np.mean(a), scipy.stats.sem(a)
38
+ h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
39
+ return m, h
40
+
41
+ class Method():
42
+ def __init__(self, name, base_dir, metrics):
43
+ self.name = name
44
+ self.base_dir = base_dir
45
+ self.metrics = {}
46
+
47
+ for i in range(len(metrics)):
48
+ metric = metrics[i]
49
+ value = []
50
+ self.metrics[metric] = value
51
+
52
+ def append(self, matric, value):
53
+ self.metrics[matric].append(value)
54
+
55
+ def get_mean_ci(self, metric):
56
+ return mean_conf_int(np.array(self.metrics[metric]))
57
+
58
+ def hp_filter(signal, cut_off=80, order=10, sr=16000):
59
+ factor = cut_off /sr * 2
60
+ sos = butter(order, factor, 'hp', output='sos')
61
+ filtered = sosfilt(sos, signal)
62
+ return filtered
63
+
64
+ def si_sdr(s, s_hat):
65
+ alpha = np.dot(s_hat, s)/np.linalg.norm(s)**2
66
+ sdr = 10*np.log10(np.linalg.norm(alpha*s)**2/np.linalg.norm(
67
+ alpha*s - s_hat)**2)
68
+ return sdr
69
+
70
+ def snr_dB(s,n):
71
+ s_power = 1/len(s)*np.sum(s**2)
72
+ n_power = 1/len(n)*np.sum(n**2)
73
+ snr_dB = 10*np.log10(s_power/n_power)
74
+ return snr_dB
75
+
76
+ def pad_spec(Y, mode="zero_pad"):
77
+ T = Y.size(3)
78
+ if T%64 !=0:
79
+ num_pad = 64-T%64
80
+ else:
81
+ num_pad = 0
82
+ if mode == "zero_pad":
83
+ pad2d = torch.nn.ZeroPad2d((0, num_pad, 0,0))
84
+ elif mode == "reflection":
85
+ pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0,0))
86
+ elif mode == "replication":
87
+ pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0,0))
88
+ else:
89
+ raise NotImplementedError("This function hasn't been implemented yet.")
90
+ return pad2d(Y)
91
+
92
+ def ensure_dir(file_path):
93
+ directory = file_path
94
+ if not os.path.exists(directory):
95
+ os.makedirs(directory)
96
+
97
+
98
+ def print_metrics(x, y, x_hat_list, labels, sr=16000):
99
+ _si_sdr_mix = si_sdr(x, y)
100
+ _pesq_mix = pesq(sr, x, y, 'wb')
101
+ _estoi_mix = stoi(x, y, sr, extended=True)
102
+ print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
103
+ for i, x_hat in enumerate(x_hat_list):
104
+ _si_sdr = si_sdr(x, x_hat)
105
+ _pesq = pesq(sr, x, x_hat, 'wb')
106
+ _estoi = stoi(x, x_hat, sr, extended=True)
107
+ print(f'{labels[i]}: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
108
+
109
+ def mean_std(data):
110
+ data = data[~np.isnan(data)]
111
+ mean = np.mean(data)
112
+ std = np.std(data)
113
+ return mean, std
114
+
115
+ def print_mean_std(data, decimal=2):
116
+ data = np.array(data)
117
+ data = data[~np.isnan(data)]
118
+ mean = np.mean(data)
119
+ std = np.std(data)
120
+ if decimal == 2:
121
+ string = f'{mean:.2f} ± {std:.2f}'
122
+ elif decimal == 1:
123
+ string = f'{mean:.1f} ± {std:.1f}'
124
+ return string
125
+
126
+ def set_torch_cuda_arch_list():
127
+ if not torch.cuda.is_available():
128
+ print("CUDA is not available. No GPUs found.")
129
+ return
130
+
131
+ num_gpus = torch.cuda.device_count()
132
+ compute_capabilities = []
133
+
134
+ for i in range(num_gpus):
135
+ cc_major, cc_minor = torch.cuda.get_device_capability(i)
136
+ cc = f"{cc_major}.{cc_minor}"
137
+ compute_capabilities.append(cc)
138
+
139
+ cc_string = ";".join(compute_capabilities)
140
+ os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
141
+ print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")
sgmse/util/registry.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Callable
3
+
4
+
5
+ class Registry:
6
+ def __init__(self, managed_thing: str):
7
+ """
8
+ Create a new registry.
9
+
10
+ Args:
11
+ managed_thing: A string describing what type of thing is managed by this registry. Will be used for
12
+ warnings and errors, so it's a good idea to keep this string globally unique and easily understood.
13
+ """
14
+ self.managed_thing = managed_thing
15
+ self._registry = {}
16
+
17
+ def register(self, name: str) -> Callable:
18
+ def inner_wrapper(wrapped_class) -> Callable:
19
+ if name in self._registry:
20
+ warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.")
21
+ self._registry[name] = wrapped_class
22
+ return wrapped_class
23
+ return inner_wrapper
24
+
25
+ def get_by_name(self, name: str):
26
+ """Get a managed thing by name."""
27
+ if name in self._registry:
28
+ return self._registry[name]
29
+ else:
30
+ raise ValueError(f"{self.managed_thing} with name '{name}' unknown.")
31
+
32
+ def get_all_names(self):
33
+ """Get the list of things' names registered to this registry."""
34
+ return list(self._registry.keys())
sgmse/util/tensors.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def batch_broadcast(a, x):
2
+ """Broadcasts a over all dimensions of x, except the batch dimension, which must match."""
3
+
4
+ if len(a.shape) != 1:
5
+ a = a.squeeze()
6
+ if len(a.shape) != 1:
7
+ raise ValueError(
8
+ f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
9
+ )
10
+
11
+ if a.shape[0] != x.shape[0] and a.shape[0] != 1:
12
+ raise ValueError(
13
+ f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")
14
+
15
+ out = a.view((x.shape[0], *(1 for _ in range(len(x.shape)-1))))
16
+ return out
train.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wandb
2
+ import argparse
3
+ import pytorch_lightning as pl
4
+
5
+ from argparse import ArgumentParser
6
+ from pytorch_lightning.loggers import WandbLogger
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+ from os.path import join
9
+
10
+ # Set CUDA architecture list
11
+ from sgmse.util.other import set_torch_cuda_arch_list
12
+ set_torch_cuda_arch_list()
13
+
14
+ from sgmse.backbones.shared import BackboneRegistry
15
+ from sgmse.data_module import SpecsDataModule
16
+ from sgmse.sdes import SDERegistry
17
+ from sgmse.model import ScoreModel
18
+
19
+
20
+ def get_argparse_groups(parser):
21
+ groups = {}
22
+ for group in parser._action_groups:
23
+ group_dict = { a.dest: getattr(args, a.dest, None) for a in group._group_actions }
24
+ groups[group.title] = argparse.Namespace(**group_dict)
25
+ return groups
26
+
27
+
28
+ if __name__ == '__main__':
29
+ # throwaway parser for dynamic args - see https://stackoverflow.com/a/25320537/3090225
30
+ base_parser = ArgumentParser(add_help=False)
31
+ parser = ArgumentParser()
32
+ for parser_ in (base_parser, parser):
33
+ parser_.add_argument("--backbone", type=str, choices=BackboneRegistry.get_all_names(), default="ncsnpp")
34
+ parser_.add_argument("--sde", type=str, choices=SDERegistry.get_all_names(), default="ouve")
35
+ parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
36
+ parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
37
+ parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
38
+ parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
39
+
40
+ temp_args, _ = base_parser.parse_known_args()
41
+
42
+ # Add specific args for ScoreModel, pl.Trainer, the SDE class and backbone DNN class
43
+ backbone_cls = BackboneRegistry.get_by_name(temp_args.backbone)
44
+ sde_class = SDERegistry.get_by_name(temp_args.sde)
45
+ trainer_parser = parser.add_argument_group("Trainer", description="Lightning Trainer")
46
+ trainer_parser.add_argument("--accelerator", type=str, default="gpu", help="Supports passing different accelerator types.")
47
+ trainer_parser.add_argument("--devices", default="auto", help="How many gpus to use.")
48
+ trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients.")
49
+
50
+ ScoreModel.add_argparse_args(
51
+ parser.add_argument_group("ScoreModel", description=ScoreModel.__name__))
52
+ sde_class.add_argparse_args(
53
+ parser.add_argument_group("SDE", description=sde_class.__name__))
54
+ backbone_cls.add_argparse_args(
55
+ parser.add_argument_group("Backbone", description=backbone_cls.__name__))
56
+ # Add data module args
57
+ data_module_cls = SpecsDataModule
58
+ data_module_cls.add_argparse_args(
59
+ parser.add_argument_group("DataModule", description=data_module_cls.__name__))
60
+ # Parse args and separate into groups
61
+ args = parser.parse_args()
62
+ arg_groups = get_argparse_groups(parser)
63
+
64
+ # Initialize logger, trainer, model, datamodule
65
+ model = ScoreModel(
66
+ backbone=args.backbone, sde=args.sde, data_module_cls=data_module_cls,
67
+ **{
68
+ **vars(arg_groups['ScoreModel']),
69
+ **vars(arg_groups['SDE']),
70
+ **vars(arg_groups['Backbone']),
71
+ **vars(arg_groups['DataModule'])
72
+ }
73
+ )
74
+
75
+ # Set up logger configuration
76
+ if args.nolog:
77
+ logger = None
78
+ else:
79
+ logger = WandbLogger(project="sgmse", log_model=True, save_dir="logs", name=args.wandb_name)
80
+ logger.experiment.log_code(".")
81
+
82
+ # Set up callbacks for logger
83
+ if logger != None:
84
+ callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
85
+ if args.num_eval_files:
86
+ checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
87
+ save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
88
+ checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
89
+ save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
90
+ callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
91
+ else:
92
+ callbacks = None
93
+
94
+ # Initialize the Trainer and the DataModule
95
+ trainer = pl.Trainer(
96
+ **vars(arg_groups['Trainer']),
97
+ strategy="ddp", logger=logger,
98
+ log_every_n_steps=10, num_sanity_val_steps=0,
99
+ callbacks=callbacks
100
+ )
101
+
102
+ # Train model
103
+ trainer.fit(model, ckpt_path=args.ckpt)