liuhuadai commited on
Commit
8e19fe0
1 Parent(s): 6e4c507

Upload 21 files

Browse files
vocoder/bigvgan/__init__.py ADDED
File without changes
vocoder/bigvgan/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (177 Bytes). View file
 
vocoder/bigvgan/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (143 Bytes). View file
 
vocoder/bigvgan/__pycache__/activations.cpython-37.pyc ADDED
Binary file (4.14 kB). View file
 
vocoder/bigvgan/__pycache__/activations.cpython-38.pyc ADDED
Binary file (4.05 kB). View file
 
vocoder/bigvgan/__pycache__/models.cpython-37.pyc ADDED
Binary file (13.7 kB). View file
 
vocoder/bigvgan/__pycache__/models.cpython-38.pyc ADDED
Binary file (13.2 kB). View file
 
vocoder/bigvgan/activations.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ '''
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ '''
25
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
26
+ '''
27
+ Initialization.
28
+ INPUT:
29
+ - in_features: shape of the input
30
+ - alpha: trainable parameter
31
+ alpha is initialized to 1 by default, higher values = higher-frequency.
32
+ alpha will be trained along with the rest of your model.
33
+ '''
34
+ super(Snake, self).__init__()
35
+ self.in_features = in_features
36
+
37
+ # initialize alpha
38
+ self.alpha_logscale = alpha_logscale
39
+ if self.alpha_logscale: # log scale alphas initialized to zeros
40
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
41
+ else: # linear scale alphas initialized to ones
42
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+
46
+ self.no_div_by_zero = 0.000000001
47
+
48
+ def forward(self, x):
49
+ '''
50
+ Forward pass of the function.
51
+ Applies the function to the input elementwise.
52
+ Snake ∶= x + 1/a * sin^2 (xa)
53
+ '''
54
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
55
+ if self.alpha_logscale:
56
+ alpha = torch.exp(alpha)
57
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
58
+
59
+ return x
60
+
61
+
62
+ class SnakeBeta(nn.Module):
63
+ '''
64
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
65
+ Shape:
66
+ - Input: (B, C, T)
67
+ - Output: (B, C, T), same shape as the input
68
+ Parameters:
69
+ - alpha - trainable parameter that controls frequency
70
+ - beta - trainable parameter that controls magnitude
71
+ References:
72
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
73
+ https://arxiv.org/abs/2006.08195
74
+ Examples:
75
+ >>> a1 = snakebeta(256)
76
+ >>> x = torch.randn(256)
77
+ >>> x = a1(x)
78
+ '''
79
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
80
+ '''
81
+ Initialization.
82
+ INPUT:
83
+ - in_features: shape of the input
84
+ - alpha - trainable parameter that controls frequency
85
+ - beta - trainable parameter that controls magnitude
86
+ alpha is initialized to 1 by default, higher values = higher-frequency.
87
+ beta is initialized to 1 by default, higher values = higher-magnitude.
88
+ alpha will be trained along with the rest of your model.
89
+ '''
90
+ super(SnakeBeta, self).__init__()
91
+ self.in_features = in_features
92
+
93
+ # initialize alpha
94
+ self.alpha_logscale = alpha_logscale
95
+ if self.alpha_logscale: # log scale alphas initialized to zeros
96
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
97
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
98
+ else: # linear scale alphas initialized to ones
99
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
100
+ self.beta = Parameter(torch.ones(in_features) * alpha)
101
+
102
+ self.alpha.requires_grad = alpha_trainable
103
+ self.beta.requires_grad = alpha_trainable
104
+
105
+ self.no_div_by_zero = 0.000000001
106
+
107
+ def forward(self, x):
108
+ '''
109
+ Forward pass of the function.
110
+ Applies the function to the input elementwise.
111
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
112
+ '''
113
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
114
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
115
+ if self.alpha_logscale:
116
+ alpha = torch.exp(alpha)
117
+ beta = torch.exp(beta)
118
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
119
+
120
+ return x
vocoder/bigvgan/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
vocoder/bigvgan/alias_free_torch/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (252 Bytes). View file
 
vocoder/bigvgan/alias_free_torch/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (218 Bytes). View file
 
vocoder/bigvgan/alias_free_torch/__pycache__/act.cpython-37.pyc ADDED
Binary file (1.04 kB). View file
 
vocoder/bigvgan/alias_free_torch/__pycache__/act.cpython-38.pyc ADDED
Binary file (1.03 kB). View file
 
vocoder/bigvgan/alias_free_torch/__pycache__/filter.cpython-37.pyc ADDED
Binary file (2.61 kB). View file
 
vocoder/bigvgan/alias_free_torch/__pycache__/filter.cpython-38.pyc ADDED
Binary file (2.61 kB). View file
 
vocoder/bigvgan/alias_free_torch/__pycache__/resample.cpython-37.pyc ADDED
Binary file (1.98 kB). View file
 
vocoder/bigvgan/alias_free_torch/__pycache__/resample.cpython-38.pyc ADDED
Binary file (1.94 kB). View file
 
vocoder/bigvgan/alias_free_torch/act.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12):
15
+ super().__init__()
16
+ self.up_ratio = up_ratio
17
+ self.down_ratio = down_ratio
18
+ self.act = activation
19
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
20
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
21
+
22
+ # x: [B,C,T]
23
+ def forward(self, x):
24
+ x = self.upsample(x)
25
+ x = self.act(x)
26
+ x = self.downsample(x)
27
+
28
+ return x
vocoder/bigvgan/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if 'sinc' in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(x == 0,
21
+ torch.tensor(1., device=x.device, dtype=x.dtype),
22
+ torch.sin(math.pi * x) / math.pi / x)
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ # LICENSE is in incl_licenses directory.
28
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
29
+ even = (kernel_size % 2 == 0)
30
+ half_size = kernel_size // 2
31
+
32
+ #For kaiser window
33
+ delta_f = 4 * half_width
34
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35
+ if A > 50.:
36
+ beta = 0.1102 * (A - 8.7)
37
+ elif A >= 21.:
38
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
39
+ else:
40
+ beta = 0.
41
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42
+
43
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
44
+ if even:
45
+ time = (torch.arange(-half_size, half_size) + 0.5)
46
+ else:
47
+ time = torch.arange(kernel_size) - half_size
48
+ if cutoff == 0:
49
+ filter_ = torch.zeros_like(time)
50
+ else:
51
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
52
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
53
+ # of the constant component in the input signal.
54
+ filter_ /= filter_.sum()
55
+ filter = filter_.view(1, 1, kernel_size)
56
+
57
+ return filter
58
+
59
+
60
+ class LowPassFilter1d(nn.Module):
61
+ def __init__(self,
62
+ cutoff=0.5,
63
+ half_width=0.6,
64
+ stride: int = 1,
65
+ padding: bool = True,
66
+ padding_mode: str = 'replicate',
67
+ kernel_size: int = 12):
68
+ # kernel_size should be even number for stylegan3 setup,
69
+ # in this implementation, odd number is also possible.
70
+ super().__init__()
71
+ if cutoff < -0.:
72
+ raise ValueError("Minimum cutoff must be larger than zero.")
73
+ if cutoff > 0.5:
74
+ raise ValueError("A cutoff above 0.5 does not make sense.")
75
+ self.kernel_size = kernel_size
76
+ self.even = (kernel_size % 2 == 0)
77
+ self.pad_left = kernel_size // 2 - int(self.even)
78
+ self.pad_right = kernel_size // 2
79
+ self.stride = stride
80
+ self.padding = padding
81
+ self.padding_mode = padding_mode
82
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
83
+ self.register_buffer("filter", filter)
84
+
85
+ #input [B, C, T]
86
+ def forward(self, x):
87
+ _, C, _ = x.shape
88
+
89
+ if self.padding:
90
+ x = F.pad(x, (self.pad_left, self.pad_right),
91
+ mode=self.padding_mode)
92
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
93
+ stride=self.stride, groups=C)
94
+
95
+ return out
vocoder/bigvgan/alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ self.stride = ratio
16
+ self.pad = self.kernel_size // ratio - 1
17
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20
+ half_width=0.6 / ratio,
21
+ kernel_size=self.kernel_size)
22
+ self.register_buffer("filter", filter)
23
+
24
+ # x: [B, C, T]
25
+ def forward(self, x):
26
+ _, C, _ = x.shape
27
+
28
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
29
+ x = self.ratio * F.conv_transpose1d(
30
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31
+ x = x[..., self.pad_left:-self.pad_right]
32
+
33
+ return x
34
+
35
+
36
+ class DownSample1d(nn.Module):
37
+ def __init__(self, ratio=2, kernel_size=None):
38
+ super().__init__()
39
+ self.ratio = ratio
40
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42
+ half_width=0.6 / ratio,
43
+ stride=ratio,
44
+ kernel_size=self.kernel_size)
45
+
46
+ def forward(self, x):
47
+ xx = self.lowpass(x)
48
+
49
+ return xx
vocoder/bigvgan/models.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.nn as nn
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ import numpy as np
14
+ from .activations import Snake,SnakeBeta
15
+ from .alias_free_torch import *
16
+ import os
17
+ from omegaconf import OmegaConf
18
+
19
+ LRELU_SLOPE = 0.1
20
+
21
+ def init_weights(m, mean=0.0, std=0.01):
22
+ classname = m.__class__.__name__
23
+ if classname.find("Conv") != -1:
24
+ m.weight.data.normal_(mean, std)
25
+
26
+
27
+ def get_padding(kernel_size, dilation=1):
28
+ return int((kernel_size*dilation - dilation)/2)
29
+
30
+ class AMPBlock1(torch.nn.Module):
31
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
32
+ super(AMPBlock1, self).__init__()
33
+ self.h = h
34
+
35
+ self.convs1 = nn.ModuleList([
36
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
37
+ padding=get_padding(kernel_size, dilation[0]))),
38
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
39
+ padding=get_padding(kernel_size, dilation[1]))),
40
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
41
+ padding=get_padding(kernel_size, dilation[2])))
42
+ ])
43
+ self.convs1.apply(init_weights)
44
+
45
+ self.convs2 = nn.ModuleList([
46
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
47
+ padding=get_padding(kernel_size, 1))),
48
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
49
+ padding=get_padding(kernel_size, 1))),
50
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
51
+ padding=get_padding(kernel_size, 1)))
52
+ ])
53
+ self.convs2.apply(init_weights)
54
+
55
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
56
+
57
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
58
+ self.activations = nn.ModuleList([
59
+ Activation1d(
60
+ activation=Snake(channels, alpha_logscale=h.snake_logscale))
61
+ for _ in range(self.num_layers)
62
+ ])
63
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
64
+ self.activations = nn.ModuleList([
65
+ Activation1d(
66
+ activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale))
67
+ for _ in range(self.num_layers)
68
+ ])
69
+ else:
70
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
71
+
72
+ def forward(self, x):
73
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
74
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
75
+ xt = a1(x)
76
+ xt = c1(xt)
77
+ xt = a2(xt)
78
+ xt = c2(xt)
79
+ x = xt + x
80
+
81
+ return x
82
+
83
+ def remove_weight_norm(self):
84
+ for l in self.convs1:
85
+ remove_weight_norm(l)
86
+ for l in self.convs2:
87
+ remove_weight_norm(l)
88
+
89
+
90
+ class AMPBlock2(torch.nn.Module):
91
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
92
+ super(AMPBlock2, self).__init__()
93
+ self.h = h
94
+
95
+ self.convs = nn.ModuleList([
96
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
97
+ padding=get_padding(kernel_size, dilation[0]))),
98
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
99
+ padding=get_padding(kernel_size, dilation[1])))
100
+ ])
101
+ self.convs.apply(init_weights)
102
+
103
+ self.num_layers = len(self.convs) # total number of conv layers
104
+
105
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
106
+ self.activations = nn.ModuleList([
107
+ Activation1d(
108
+ activation=Snake(channels, alpha_logscale=h.snake_logscale))
109
+ for _ in range(self.num_layers)
110
+ ])
111
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
112
+ self.activations = nn.ModuleList([
113
+ Activation1d(
114
+ activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale))
115
+ for _ in range(self.num_layers)
116
+ ])
117
+ else:
118
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
119
+
120
+ def forward(self, x):
121
+ for c, a in zip (self.convs, self.activations):
122
+ xt = a(x)
123
+ xt = c(xt)
124
+ x = xt + x
125
+
126
+ return x
127
+
128
+ def remove_weight_norm(self):
129
+ for l in self.convs:
130
+ remove_weight_norm(l)
131
+
132
+
133
+ class BigVGAN(torch.nn.Module):
134
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
135
+ def __init__(self, h):
136
+ super(BigVGAN, self).__init__()
137
+ self.h = h
138
+
139
+ self.num_kernels = len(h.resblock_kernel_sizes)
140
+ self.num_upsamples = len(h.upsample_rates)
141
+
142
+ # pre conv
143
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
144
+
145
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
146
+ resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
147
+
148
+ # transposed conv-based upsamplers. does not apply anti-aliasing
149
+ self.ups = nn.ModuleList()
150
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
151
+ self.ups.append(nn.ModuleList([
152
+ weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
153
+ h.upsample_initial_channel // (2 ** (i + 1)),
154
+ k, u, padding=(k - u) // 2))
155
+ ]))
156
+
157
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
158
+ self.resblocks = nn.ModuleList()
159
+ for i in range(len(self.ups)):
160
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
161
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
162
+ self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
163
+
164
+ # post conv
165
+ if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
166
+ activation_post = Snake(ch, alpha_logscale=h.snake_logscale)
167
+ self.activation_post = Activation1d(activation=activation_post)
168
+ elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
169
+ activation_post = SnakeBeta(ch, alpha_logscale=h.snake_logscale)
170
+ self.activation_post = Activation1d(activation=activation_post)
171
+ else:
172
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
173
+
174
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
175
+
176
+ # weight initialization
177
+ for i in range(len(self.ups)):
178
+ self.ups[i].apply(init_weights)
179
+ self.conv_post.apply(init_weights)
180
+
181
+ def forward(self, x):
182
+ # pre conv
183
+ x = self.conv_pre(x)
184
+
185
+ for i in range(self.num_upsamples):
186
+ # upsampling
187
+ for i_up in range(len(self.ups[i])):
188
+ x = self.ups[i][i_up](x)
189
+ # AMP blocks
190
+ xs = None
191
+ for j in range(self.num_kernels):
192
+ if xs is None:
193
+ xs = self.resblocks[i * self.num_kernels + j](x)
194
+ else:
195
+ xs += self.resblocks[i * self.num_kernels + j](x)
196
+ x = xs / self.num_kernels
197
+
198
+ # post conv
199
+ x = self.activation_post(x)
200
+ x = self.conv_post(x)
201
+ x = torch.tanh(x)
202
+
203
+ return x
204
+
205
+ def remove_weight_norm(self):
206
+ print('Removing weight norm...')
207
+ for l in self.ups:
208
+ for l_i in l:
209
+ remove_weight_norm(l_i)
210
+ for l in self.resblocks:
211
+ l.remove_weight_norm()
212
+ remove_weight_norm(self.conv_pre)
213
+ remove_weight_norm(self.conv_post)
214
+
215
+
216
+ class DiscriminatorP(torch.nn.Module):
217
+ def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
218
+ super(DiscriminatorP, self).__init__()
219
+ self.period = period
220
+ self.d_mult = h.discriminator_channel_mult
221
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
222
+ self.convs = nn.ModuleList([
223
+ norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
224
+ norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
225
+ norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
226
+ norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
227
+ norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
228
+ ])
229
+ self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
230
+
231
+ def forward(self, x):
232
+ fmap = []
233
+
234
+ # 1d to 2d
235
+ b, c, t = x.shape
236
+ if t % self.period != 0: # pad first
237
+ n_pad = self.period - (t % self.period)
238
+ x = F.pad(x, (0, n_pad), "reflect")
239
+ t = t + n_pad
240
+ x = x.view(b, c, t // self.period, self.period)
241
+
242
+ for l in self.convs:
243
+ x = l(x)
244
+ x = F.leaky_relu(x, LRELU_SLOPE)
245
+ fmap.append(x)
246
+ x = self.conv_post(x)
247
+ fmap.append(x)
248
+ x = torch.flatten(x, 1, -1)
249
+
250
+ return x, fmap
251
+
252
+
253
+ class MultiPeriodDiscriminator(torch.nn.Module):
254
+ def __init__(self, h):
255
+ super(MultiPeriodDiscriminator, self).__init__()
256
+ self.mpd_reshapes = h.mpd_reshapes
257
+ print("mpd_reshapes: {}".format(self.mpd_reshapes))
258
+ discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
259
+ self.discriminators = nn.ModuleList(discriminators)
260
+
261
+ def forward(self, y, y_hat):
262
+ y_d_rs = []
263
+ y_d_gs = []
264
+ fmap_rs = []
265
+ fmap_gs = []
266
+ for i, d in enumerate(self.discriminators):
267
+ y_d_r, fmap_r = d(y)
268
+ y_d_g, fmap_g = d(y_hat)
269
+ y_d_rs.append(y_d_r)
270
+ fmap_rs.append(fmap_r)
271
+ y_d_gs.append(y_d_g)
272
+ fmap_gs.append(fmap_g)
273
+
274
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
275
+
276
+
277
+ class DiscriminatorR(nn.Module):
278
+ def __init__(self, cfg, resolution):
279
+ super().__init__()
280
+
281
+ self.resolution = resolution
282
+ assert len(self.resolution) == 3, \
283
+ "MRD layer requires list with len=3, got {}".format(self.resolution)
284
+ self.lrelu_slope = LRELU_SLOPE
285
+
286
+ norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
287
+ if hasattr(cfg, "mrd_use_spectral_norm"):
288
+ print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm))
289
+ norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
290
+ self.d_mult = cfg.discriminator_channel_mult
291
+ if hasattr(cfg, "mrd_channel_mult"):
292
+ print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult))
293
+ self.d_mult = cfg.mrd_channel_mult
294
+
295
+ self.convs = nn.ModuleList([
296
+ norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))),
297
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
298
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
299
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
300
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))),
301
+ ])
302
+ self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
303
+
304
+ def forward(self, x):
305
+ fmap = []
306
+
307
+ x = self.spectrogram(x)
308
+ x = x.unsqueeze(1)
309
+ for l in self.convs:
310
+ x = l(x)
311
+ x = F.leaky_relu(x, self.lrelu_slope)
312
+ fmap.append(x)
313
+ x = self.conv_post(x)
314
+ fmap.append(x)
315
+ x = torch.flatten(x, 1, -1)
316
+
317
+ return x, fmap
318
+
319
+ def spectrogram(self, x):
320
+ n_fft, hop_length, win_length = self.resolution
321
+ x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
322
+ x = x.squeeze(1)
323
+ x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
324
+ x = torch.view_as_real(x) # [B, F, TT, 2]
325
+ mag = torch.norm(x, p=2, dim =-1) #[B, F, TT]
326
+
327
+ return mag
328
+
329
+
330
+ class MultiResolutionDiscriminator(nn.Module):
331
+ def __init__(self, cfg, debug=False):
332
+ super().__init__()
333
+ self.resolutions = cfg.resolutions
334
+ assert len(self.resolutions) == 3,\
335
+ "MRD requires list of list with len=3, each element having a list with len=3. got {}".\
336
+ format(self.resolutions)
337
+ self.discriminators = nn.ModuleList(
338
+ [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
339
+ )
340
+
341
+ def forward(self, y, y_hat):
342
+ y_d_rs = []
343
+ y_d_gs = []
344
+ fmap_rs = []
345
+ fmap_gs = []
346
+
347
+ for i, d in enumerate(self.discriminators):
348
+ y_d_r, fmap_r = d(x=y)
349
+ y_d_g, fmap_g = d(x=y_hat)
350
+ y_d_rs.append(y_d_r)
351
+ fmap_rs.append(fmap_r)
352
+ y_d_gs.append(y_d_g)
353
+ fmap_gs.append(fmap_g)
354
+
355
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
356
+
357
+
358
+ def feature_loss(fmap_r, fmap_g):
359
+ loss = 0
360
+ for dr, dg in zip(fmap_r, fmap_g):
361
+ for rl, gl in zip(dr, dg):
362
+ loss += torch.mean(torch.abs(rl - gl))
363
+
364
+ return loss*2
365
+
366
+
367
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
368
+ loss = 0
369
+ r_losses = []
370
+ g_losses = []
371
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
372
+ r_loss = torch.mean((1-dr)**2)
373
+ g_loss = torch.mean(dg**2)
374
+ loss += (r_loss + g_loss)
375
+ r_losses.append(r_loss.item())
376
+ g_losses.append(g_loss.item())
377
+
378
+ return loss, r_losses, g_losses
379
+
380
+
381
+ def generator_loss(disc_outputs):
382
+ loss = 0
383
+ gen_losses = []
384
+ for dg in disc_outputs:
385
+ l = torch.mean((1-dg)**2)
386
+ gen_losses.append(l)
387
+ loss += l
388
+
389
+ return loss, gen_losses
390
+
391
+
392
+
393
+ class VocoderBigVGAN(object):
394
+ def __init__(self, ckpt_vocoder,device='cuda'):
395
+ vocoder_sd = torch.load(os.path.join(ckpt_vocoder,'best_netG.pt'), map_location='cpu')
396
+
397
+ vocoder_args = OmegaConf.load(os.path.join(ckpt_vocoder,'args.yml'))
398
+
399
+ self.generator = BigVGAN(vocoder_args)
400
+ self.generator.load_state_dict(vocoder_sd['generator'])
401
+ self.generator.eval()
402
+
403
+ self.device = device
404
+ self.generator.to(self.device)
405
+
406
+ def vocode(self, spec):
407
+ with torch.no_grad():
408
+ if isinstance(spec,np.ndarray):
409
+ spec = torch.from_numpy(spec).unsqueeze(0)
410
+ spec = spec.to(dtype=torch.float32,device=self.device)
411
+ return self.generator(spec).squeeze().cpu().numpy()
412
+
413
+ def __call__(self, wav):
414
+ return self.vocode(wav)