Spaces:
Running
on
Zero
Running
on
Zero
Upload 35 files
Browse files- app.py +141 -0
- campplus_cn_common.bin +3 -0
- configs/config_dit_mel_seed.yml +79 -0
- configs/hifigan.yml +25 -0
- hf_utils.py +12 -0
- modules/__pycache__/audio.cpython-310.pyc +0 -0
- modules/__pycache__/commons.cpython-310.pyc +0 -0
- modules/__pycache__/diffusion_transformer.cpython-310.pyc +0 -0
- modules/__pycache__/encodec.cpython-310.pyc +0 -0
- modules/__pycache__/flow_matching.cpython-310.pyc +0 -0
- modules/__pycache__/length_regulator.cpython-310.pyc +0 -0
- modules/__pycache__/wavenet.cpython-310.pyc +0 -0
- modules/audio.py +82 -0
- modules/campplus/DTDNN.py +115 -0
- modules/campplus/__pycache__/DTDNN.cpython-310.pyc +0 -0
- modules/campplus/__pycache__/layers.cpython-310.pyc +0 -0
- modules/campplus/classifier.py +70 -0
- modules/campplus/layers.py +253 -0
- modules/commons.py +452 -0
- modules/cosyvoice_tokenizer/__pycache__/frontend.cpython-310.pyc +0 -0
- modules/cosyvoice_tokenizer/frontend.py +54 -0
- modules/diffusion_transformer.py +237 -0
- modules/encodec.py +292 -0
- modules/flow_matching.py +153 -0
- modules/gpt_fast/__pycache__/model.cpython-310.pyc +0 -0
- modules/gpt_fast/generate.py +436 -0
- modules/gpt_fast/model.py +356 -0
- modules/gpt_fast/quantize.py +622 -0
- modules/hifigan/__pycache__/f0_predictor.cpython-310.pyc +0 -0
- modules/hifigan/__pycache__/generator.cpython-310.pyc +0 -0
- modules/hifigan/f0_predictor.py +55 -0
- modules/hifigan/generator.py +453 -0
- modules/layers.py +354 -0
- modules/length_regulator.py +42 -0
- modules/wavenet.py +174 -0
app.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
from modules.commons import build_model, load_checkpoint, recursive_munch
|
7 |
+
import yaml
|
8 |
+
from hf_utils import load_custom_model_from_hf
|
9 |
+
|
10 |
+
# Load model and configuration
|
11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
+
|
13 |
+
dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
|
14 |
+
"DiT_step_315000_seed_v2_online_pruned.pth",
|
15 |
+
"config_dit_mel_seed.yml")
|
16 |
+
|
17 |
+
config = yaml.safe_load(open(dit_config_path, 'r'))
|
18 |
+
model_params = recursive_munch(config['model_params'])
|
19 |
+
model = build_model(model_params, stage='DiT')
|
20 |
+
hop_length = config['preprocess_params']['spect_params']['hop_length']
|
21 |
+
sr = config['preprocess_params']['sr']
|
22 |
+
|
23 |
+
# Load checkpoints
|
24 |
+
model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
|
25 |
+
load_only_params=True, ignore_modules=[], is_distributed=False)
|
26 |
+
for key in model:
|
27 |
+
model[key].eval()
|
28 |
+
model[key].to(device)
|
29 |
+
model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
|
30 |
+
|
31 |
+
# Load additional modules
|
32 |
+
from modules.campplus.DTDNN import CAMPPlus
|
33 |
+
|
34 |
+
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
|
35 |
+
campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path']))
|
36 |
+
campplus_model.eval()
|
37 |
+
campplus_model.to(device)
|
38 |
+
|
39 |
+
from modules.hifigan.generator import HiFTGenerator
|
40 |
+
from modules.hifigan.f0_predictor import ConvRNNF0Predictor
|
41 |
+
|
42 |
+
hift_checkpoint_path, hift_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
|
43 |
+
"hift.pt",
|
44 |
+
"hifigan.yml")
|
45 |
+
hift_config = yaml.safe_load(open(hift_config_path, 'r'))
|
46 |
+
hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
|
47 |
+
hift_gen.load_state_dict(torch.load(hift_config['pretrained_model_path'], map_location='cpu'))
|
48 |
+
hift_gen.eval()
|
49 |
+
hift_gen.to(device)
|
50 |
+
|
51 |
+
from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
|
52 |
+
|
53 |
+
speech_tokenizer_path = load_custom_model_from_hf("Plachta/Seed-VC", "speech_tokenizer_v1.onnx", None)
|
54 |
+
|
55 |
+
cosyvoice_frontend = CosyVoiceFrontEnd(speech_tokenizer_model=speech_tokenizer_path,
|
56 |
+
device='cuda', device_id=0)
|
57 |
+
# Generate mel spectrograms
|
58 |
+
mel_fn_args = {
|
59 |
+
"n_fft": config['preprocess_params']['spect_params']['n_fft'],
|
60 |
+
"win_size": config['preprocess_params']['spect_params']['win_length'],
|
61 |
+
"hop_size": config['preprocess_params']['spect_params']['hop_length'],
|
62 |
+
"num_mels": config['preprocess_params']['spect_params']['n_mels'],
|
63 |
+
"sampling_rate": sr,
|
64 |
+
"fmin": 0,
|
65 |
+
"fmax": 8000,
|
66 |
+
"center": False
|
67 |
+
}
|
68 |
+
from modules.audio import mel_spectrogram
|
69 |
+
|
70 |
+
to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
|
71 |
+
|
72 |
+
@spaces.GPU
|
73 |
+
@torch.no_grad()
|
74 |
+
@torch.inference_mode()
|
75 |
+
def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate):
|
76 |
+
# Load audio
|
77 |
+
source_audio = librosa.load(source, sr=sr)[0]
|
78 |
+
ref_audio = librosa.load(target, sr=sr)[0]
|
79 |
+
|
80 |
+
# Process audio
|
81 |
+
source_audio = torch.tensor(source_audio[:sr * 30]).unsqueeze(0).float().to(device)
|
82 |
+
ref_audio = torch.tensor(ref_audio[:sr * 30]).unsqueeze(0).float().to(device)
|
83 |
+
|
84 |
+
# Resample
|
85 |
+
source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
|
86 |
+
ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
|
87 |
+
|
88 |
+
# Extract features
|
89 |
+
S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
|
90 |
+
S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
|
91 |
+
|
92 |
+
mel = to_mel(source_audio.to(device).float())
|
93 |
+
mel2 = to_mel(ref_audio.to(device).float())
|
94 |
+
|
95 |
+
target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
|
96 |
+
target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
|
97 |
+
|
98 |
+
# Style encoding
|
99 |
+
feat = torchaudio.compliance.kaldi.fbank(source_waves_16k,
|
100 |
+
num_mel_bins=80,
|
101 |
+
dither=0,
|
102 |
+
sample_frequency=16000)
|
103 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
104 |
+
style1 = campplus_model(feat.unsqueeze(0))
|
105 |
+
|
106 |
+
feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
|
107 |
+
num_mel_bins=80,
|
108 |
+
dither=0,
|
109 |
+
sample_frequency=16000)
|
110 |
+
feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
|
111 |
+
style2 = campplus_model(feat2.unsqueeze(0))
|
112 |
+
|
113 |
+
# Length regulation
|
114 |
+
cond = model.length_regulator(S_alt, ylens=target_lengths)[0]
|
115 |
+
prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths)[0]
|
116 |
+
cat_condition = torch.cat([prompt_condition, cond], dim=1)
|
117 |
+
|
118 |
+
# Voice Conversion
|
119 |
+
vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
|
120 |
+
mel2, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
|
121 |
+
vc_target = vc_target[:, :, mel2.size(-1):]
|
122 |
+
|
123 |
+
# Convert to waveform
|
124 |
+
vc_wave = hift_gen.inference(vc_target)
|
125 |
+
|
126 |
+
return (sr, vc_wave.squeeze(0).cpu().numpy())
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == "__main__":
|
130 |
+
description = "Zero-shot voice conversion with in-context learning. Check out our [GitHub repository](https://github.com/Plachtaa/seed-vc) for details and updates."
|
131 |
+
inputs = [
|
132 |
+
gr.Audio(source="upload", type="filepath", label="Source Audio"),
|
133 |
+
gr.Audio(source="upload", type="filepath", label="Reference Audio"),
|
134 |
+
gr.Slider(minimum=1, maximum=1000, value=100, step=1, label="Diffusion Steps"),
|
135 |
+
gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust"),
|
136 |
+
gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate"),
|
137 |
+
]
|
138 |
+
|
139 |
+
outputs = gr.Audio(label="Output Audio")
|
140 |
+
|
141 |
+
gr.Interface(fn=voice_conversion, description=description, inputs=inputs, outputs=outputs, title="Seed Voice Conversion").launch()
|
campplus_cn_common.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3388cf5fd3493c9ac9c69851d8e7a8badcfb4f3dc631020c4961371646d5ada8
|
3 |
+
size 28036335
|
configs/config_dit_mel_seed.yml
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
log_dir: "./runs/run_dit_mel_seed"
|
2 |
+
save_freq: 1
|
3 |
+
log_interval: 10
|
4 |
+
save_interval: 1000
|
5 |
+
device: "cuda"
|
6 |
+
epochs: 1000 # number of epochs for first stage training (pre-training)
|
7 |
+
batch_size: 4
|
8 |
+
batch_length: 100 # maximum duration of audio in a batch (in seconds)
|
9 |
+
max_len: 80 # maximum number of frames
|
10 |
+
pretrained_model: ""
|
11 |
+
pretrained_encoder: ""
|
12 |
+
load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
|
13 |
+
|
14 |
+
F0_path: "modules/JDC/bst.t7"
|
15 |
+
|
16 |
+
preprocess_params:
|
17 |
+
sr: 22050
|
18 |
+
spect_params:
|
19 |
+
n_fft: 1024
|
20 |
+
win_length: 1024
|
21 |
+
hop_length: 256
|
22 |
+
n_mels: 80
|
23 |
+
|
24 |
+
model_params:
|
25 |
+
dit_type: "DiT" # uDiT or DiT
|
26 |
+
reg_loss_type: "l2" # l1 or l2
|
27 |
+
|
28 |
+
speech_tokenizer:
|
29 |
+
path: "speech_tokenizer_v1.onnx"
|
30 |
+
|
31 |
+
style_encoder:
|
32 |
+
dim: 192
|
33 |
+
campplus_path: "campplus_cn_common.bin"
|
34 |
+
|
35 |
+
DAC:
|
36 |
+
encoder_dim: 64
|
37 |
+
encoder_rates: [2, 5, 5, 6]
|
38 |
+
decoder_dim: 1536
|
39 |
+
decoder_rates: [ 6, 5, 5, 2 ]
|
40 |
+
sr: 24000
|
41 |
+
|
42 |
+
length_regulator:
|
43 |
+
channels: 768
|
44 |
+
is_discrete: true
|
45 |
+
content_codebook_size: 4096
|
46 |
+
in_frame_rate: 50
|
47 |
+
out_frame_rate: 80
|
48 |
+
sampling_ratios: [1, 1, 1, 1]
|
49 |
+
|
50 |
+
DiT:
|
51 |
+
hidden_dim: 768
|
52 |
+
num_heads: 12
|
53 |
+
depth: 12
|
54 |
+
class_dropout_prob: 0.1
|
55 |
+
block_size: 4096
|
56 |
+
in_channels: 80
|
57 |
+
style_condition: true
|
58 |
+
final_layer_type: 'wavenet'
|
59 |
+
target: 'mel' # mel or codec
|
60 |
+
content_dim: 768
|
61 |
+
content_codebook_size: 1024
|
62 |
+
content_type: 'discrete'
|
63 |
+
f0_condition: false
|
64 |
+
n_f0_bins: 512
|
65 |
+
content_codebooks: 1
|
66 |
+
is_causal: false
|
67 |
+
long_skip_connection: true
|
68 |
+
zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
|
69 |
+
|
70 |
+
wavenet:
|
71 |
+
hidden_dim: 768
|
72 |
+
num_layers: 8
|
73 |
+
kernel_size: 5
|
74 |
+
dilation_rate: 1
|
75 |
+
p_dropout: 0.2
|
76 |
+
style_condition: true
|
77 |
+
|
78 |
+
loss_params:
|
79 |
+
base_lr: 0.0001
|
configs/hifigan.yml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hift:
|
2 |
+
in_channels: 80
|
3 |
+
base_channels: 512
|
4 |
+
nb_harmonics: 8
|
5 |
+
sampling_rate: 22050
|
6 |
+
nsf_alpha: 0.1
|
7 |
+
nsf_sigma: 0.003
|
8 |
+
nsf_voiced_threshold: 10
|
9 |
+
upsample_rates: [8, 8]
|
10 |
+
upsample_kernel_sizes: [16, 16]
|
11 |
+
istft_params:
|
12 |
+
n_fft: 16
|
13 |
+
hop_len: 4
|
14 |
+
resblock_kernel_sizes: [3, 7, 11]
|
15 |
+
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
16 |
+
source_resblock_kernel_sizes: [7, 11]
|
17 |
+
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
18 |
+
lrelu_slope: 0.1
|
19 |
+
audio_limit: 0.99
|
20 |
+
f0_predictor:
|
21 |
+
num_class: 1
|
22 |
+
in_channels: 80
|
23 |
+
cond_channels: 512
|
24 |
+
|
25 |
+
pretrained_model_path: "hift.pt"
|
hf_utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from huggingface_hub import hf_hub_download
|
3 |
+
|
4 |
+
|
5 |
+
def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"):
|
6 |
+
os.makedirs("./checkpoints", exist_ok=True)
|
7 |
+
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
|
8 |
+
if config_filename is None:
|
9 |
+
return model_path
|
10 |
+
config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints")
|
11 |
+
|
12 |
+
return model_path, config_path
|
modules/__pycache__/audio.cpython-310.pyc
ADDED
Binary file (2.43 kB). View file
|
|
modules/__pycache__/commons.cpython-310.pyc
ADDED
Binary file (12.6 kB). View file
|
|
modules/__pycache__/diffusion_transformer.cpython-310.pyc
ADDED
Binary file (7.76 kB). View file
|
|
modules/__pycache__/encodec.cpython-310.pyc
ADDED
Binary file (10.8 kB). View file
|
|
modules/__pycache__/flow_matching.cpython-310.pyc
ADDED
Binary file (5.11 kB). View file
|
|
modules/__pycache__/length_regulator.cpython-310.pyc
ADDED
Binary file (1.58 kB). View file
|
|
modules/__pycache__/wavenet.cpython-310.pyc
ADDED
Binary file (5.15 kB). View file
|
|
modules/audio.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.utils.data
|
4 |
+
from librosa.filters import mel as librosa_mel_fn
|
5 |
+
from scipy.io.wavfile import read
|
6 |
+
|
7 |
+
MAX_WAV_VALUE = 32768.0
|
8 |
+
|
9 |
+
|
10 |
+
def load_wav(full_path):
|
11 |
+
sampling_rate, data = read(full_path)
|
12 |
+
return data, sampling_rate
|
13 |
+
|
14 |
+
|
15 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
16 |
+
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
17 |
+
|
18 |
+
|
19 |
+
def dynamic_range_decompression(x, C=1):
|
20 |
+
return np.exp(x) / C
|
21 |
+
|
22 |
+
|
23 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
24 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
25 |
+
|
26 |
+
|
27 |
+
def dynamic_range_decompression_torch(x, C=1):
|
28 |
+
return torch.exp(x) / C
|
29 |
+
|
30 |
+
|
31 |
+
def spectral_normalize_torch(magnitudes):
|
32 |
+
output = dynamic_range_compression_torch(magnitudes)
|
33 |
+
return output
|
34 |
+
|
35 |
+
|
36 |
+
def spectral_de_normalize_torch(magnitudes):
|
37 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
38 |
+
return output
|
39 |
+
|
40 |
+
|
41 |
+
mel_basis = {}
|
42 |
+
hann_window = {}
|
43 |
+
|
44 |
+
|
45 |
+
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
46 |
+
if torch.min(y) < -1.0:
|
47 |
+
print("min value is ", torch.min(y))
|
48 |
+
if torch.max(y) > 1.0:
|
49 |
+
print("max value is ", torch.max(y))
|
50 |
+
|
51 |
+
global mel_basis, hann_window # pylint: disable=global-statement
|
52 |
+
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
|
53 |
+
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
54 |
+
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
55 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
56 |
+
|
57 |
+
y = torch.nn.functional.pad(
|
58 |
+
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
59 |
+
)
|
60 |
+
y = y.squeeze(1)
|
61 |
+
|
62 |
+
spec = torch.view_as_real(
|
63 |
+
torch.stft(
|
64 |
+
y,
|
65 |
+
n_fft,
|
66 |
+
hop_length=hop_size,
|
67 |
+
win_length=win_size,
|
68 |
+
window=hann_window[str(y.device)],
|
69 |
+
center=center,
|
70 |
+
pad_mode="reflect",
|
71 |
+
normalized=False,
|
72 |
+
onesided=True,
|
73 |
+
return_complex=True,
|
74 |
+
)
|
75 |
+
)
|
76 |
+
|
77 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
78 |
+
|
79 |
+
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
|
80 |
+
spec = spectral_normalize_torch(spec)
|
81 |
+
|
82 |
+
return spec
|
modules/campplus/DTDNN.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
2 |
+
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from modules.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear
|
11 |
+
|
12 |
+
|
13 |
+
class FCM(nn.Module):
|
14 |
+
def __init__(self,
|
15 |
+
block=BasicResBlock,
|
16 |
+
num_blocks=[2, 2],
|
17 |
+
m_channels=32,
|
18 |
+
feat_dim=80):
|
19 |
+
super(FCM, self).__init__()
|
20 |
+
self.in_planes = m_channels
|
21 |
+
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
22 |
+
self.bn1 = nn.BatchNorm2d(m_channels)
|
23 |
+
|
24 |
+
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
|
25 |
+
self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
|
26 |
+
|
27 |
+
self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
|
28 |
+
self.bn2 = nn.BatchNorm2d(m_channels)
|
29 |
+
self.out_channels = m_channels * (feat_dim // 8)
|
30 |
+
|
31 |
+
def _make_layer(self, block, planes, num_blocks, stride):
|
32 |
+
strides = [stride] + [1] * (num_blocks - 1)
|
33 |
+
layers = []
|
34 |
+
for stride in strides:
|
35 |
+
layers.append(block(self.in_planes, planes, stride))
|
36 |
+
self.in_planes = planes * block.expansion
|
37 |
+
return nn.Sequential(*layers)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
x = x.unsqueeze(1)
|
41 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
42 |
+
out = self.layer1(out)
|
43 |
+
out = self.layer2(out)
|
44 |
+
out = F.relu(self.bn2(self.conv2(out)))
|
45 |
+
|
46 |
+
shape = out.shape
|
47 |
+
out = out.reshape(shape[0], shape[1]*shape[2], shape[3])
|
48 |
+
return out
|
49 |
+
|
50 |
+
class CAMPPlus(nn.Module):
|
51 |
+
def __init__(self,
|
52 |
+
feat_dim=80,
|
53 |
+
embedding_size=512,
|
54 |
+
growth_rate=32,
|
55 |
+
bn_size=4,
|
56 |
+
init_channels=128,
|
57 |
+
config_str='batchnorm-relu',
|
58 |
+
memory_efficient=True):
|
59 |
+
super(CAMPPlus, self).__init__()
|
60 |
+
|
61 |
+
self.head = FCM(feat_dim=feat_dim)
|
62 |
+
channels = self.head.out_channels
|
63 |
+
|
64 |
+
self.xvector = nn.Sequential(
|
65 |
+
OrderedDict([
|
66 |
+
|
67 |
+
('tdnn',
|
68 |
+
TDNNLayer(channels,
|
69 |
+
init_channels,
|
70 |
+
5,
|
71 |
+
stride=2,
|
72 |
+
dilation=1,
|
73 |
+
padding=-1,
|
74 |
+
config_str=config_str)),
|
75 |
+
]))
|
76 |
+
channels = init_channels
|
77 |
+
for i, (num_layers, kernel_size,
|
78 |
+
dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
|
79 |
+
block = CAMDenseTDNNBlock(num_layers=num_layers,
|
80 |
+
in_channels=channels,
|
81 |
+
out_channels=growth_rate,
|
82 |
+
bn_channels=bn_size * growth_rate,
|
83 |
+
kernel_size=kernel_size,
|
84 |
+
dilation=dilation,
|
85 |
+
config_str=config_str,
|
86 |
+
memory_efficient=memory_efficient)
|
87 |
+
self.xvector.add_module('block%d' % (i + 1), block)
|
88 |
+
channels = channels + num_layers * growth_rate
|
89 |
+
self.xvector.add_module(
|
90 |
+
'transit%d' % (i + 1),
|
91 |
+
TransitLayer(channels,
|
92 |
+
channels // 2,
|
93 |
+
bias=False,
|
94 |
+
config_str=config_str))
|
95 |
+
channels //= 2
|
96 |
+
|
97 |
+
self.xvector.add_module(
|
98 |
+
'out_nonlinear', get_nonlinear(config_str, channels))
|
99 |
+
|
100 |
+
self.xvector.add_module('stats', StatsPool())
|
101 |
+
self.xvector.add_module(
|
102 |
+
'dense',
|
103 |
+
DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
|
104 |
+
|
105 |
+
for m in self.modules():
|
106 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
107 |
+
nn.init.kaiming_normal_(m.weight.data)
|
108 |
+
if m.bias is not None:
|
109 |
+
nn.init.zeros_(m.bias)
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
113 |
+
x = self.head(x)
|
114 |
+
x = self.xvector(x)
|
115 |
+
return x
|
modules/campplus/__pycache__/DTDNN.cpython-310.pyc
ADDED
Binary file (3.45 kB). View file
|
|
modules/campplus/__pycache__/layers.cpython-310.pyc
ADDED
Binary file (7.3 kB). View file
|
|
modules/campplus/classifier.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
2 |
+
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from modules.campplus.layers import DenseLayer
|
9 |
+
|
10 |
+
|
11 |
+
class CosineClassifier(nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
input_dim,
|
15 |
+
num_blocks=0,
|
16 |
+
inter_dim=512,
|
17 |
+
out_neurons=1000,
|
18 |
+
):
|
19 |
+
|
20 |
+
super().__init__()
|
21 |
+
self.blocks = nn.ModuleList()
|
22 |
+
|
23 |
+
for index in range(num_blocks):
|
24 |
+
self.blocks.append(
|
25 |
+
DenseLayer(input_dim, inter_dim, config_str='batchnorm')
|
26 |
+
)
|
27 |
+
input_dim = inter_dim
|
28 |
+
|
29 |
+
self.weight = nn.Parameter(
|
30 |
+
torch.FloatTensor(out_neurons, input_dim)
|
31 |
+
)
|
32 |
+
nn.init.xavier_uniform_(self.weight)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
# x: [B, dim]
|
36 |
+
for layer in self.blocks:
|
37 |
+
x = layer(x)
|
38 |
+
|
39 |
+
# normalized
|
40 |
+
x = F.linear(F.normalize(x), F.normalize(self.weight))
|
41 |
+
return x
|
42 |
+
|
43 |
+
class LinearClassifier(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
input_dim,
|
47 |
+
num_blocks=0,
|
48 |
+
inter_dim=512,
|
49 |
+
out_neurons=1000,
|
50 |
+
):
|
51 |
+
|
52 |
+
super().__init__()
|
53 |
+
self.blocks = nn.ModuleList()
|
54 |
+
|
55 |
+
self.nonlinear = nn.ReLU(inplace=True)
|
56 |
+
for index in range(num_blocks):
|
57 |
+
self.blocks.append(
|
58 |
+
DenseLayer(input_dim, inter_dim, bias=True)
|
59 |
+
)
|
60 |
+
input_dim = inter_dim
|
61 |
+
|
62 |
+
self.linear = nn.Linear(input_dim, out_neurons, bias=True)
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
# x: [B, dim]
|
66 |
+
x = self.nonlinear(x)
|
67 |
+
for layer in self.blocks:
|
68 |
+
x = layer(x)
|
69 |
+
x = self.linear(x)
|
70 |
+
return x
|
modules/campplus/layers.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
2 |
+
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.utils.checkpoint as cp
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
def get_nonlinear(config_str, channels):
|
11 |
+
nonlinear = nn.Sequential()
|
12 |
+
for name in config_str.split('-'):
|
13 |
+
if name == 'relu':
|
14 |
+
nonlinear.add_module('relu', nn.ReLU(inplace=True))
|
15 |
+
elif name == 'prelu':
|
16 |
+
nonlinear.add_module('prelu', nn.PReLU(channels))
|
17 |
+
elif name == 'batchnorm':
|
18 |
+
nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
|
19 |
+
elif name == 'batchnorm_':
|
20 |
+
nonlinear.add_module('batchnorm',
|
21 |
+
nn.BatchNorm1d(channels, affine=False))
|
22 |
+
else:
|
23 |
+
raise ValueError('Unexpected module ({}).'.format(name))
|
24 |
+
return nonlinear
|
25 |
+
|
26 |
+
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
|
27 |
+
mean = x.mean(dim=dim)
|
28 |
+
std = x.std(dim=dim, unbiased=unbiased)
|
29 |
+
stats = torch.cat([mean, std], dim=-1)
|
30 |
+
if keepdim:
|
31 |
+
stats = stats.unsqueeze(dim=dim)
|
32 |
+
return stats
|
33 |
+
|
34 |
+
|
35 |
+
class StatsPool(nn.Module):
|
36 |
+
def forward(self, x):
|
37 |
+
return statistics_pooling(x)
|
38 |
+
|
39 |
+
|
40 |
+
class TDNNLayer(nn.Module):
|
41 |
+
def __init__(self,
|
42 |
+
in_channels,
|
43 |
+
out_channels,
|
44 |
+
kernel_size,
|
45 |
+
stride=1,
|
46 |
+
padding=0,
|
47 |
+
dilation=1,
|
48 |
+
bias=False,
|
49 |
+
config_str='batchnorm-relu'):
|
50 |
+
super(TDNNLayer, self).__init__()
|
51 |
+
if padding < 0:
|
52 |
+
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
|
53 |
+
kernel_size)
|
54 |
+
padding = (kernel_size - 1) // 2 * dilation
|
55 |
+
self.linear = nn.Conv1d(in_channels,
|
56 |
+
out_channels,
|
57 |
+
kernel_size,
|
58 |
+
stride=stride,
|
59 |
+
padding=padding,
|
60 |
+
dilation=dilation,
|
61 |
+
bias=bias)
|
62 |
+
self.nonlinear = get_nonlinear(config_str, out_channels)
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
x = self.linear(x)
|
66 |
+
x = self.nonlinear(x)
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
class CAMLayer(nn.Module):
|
71 |
+
def __init__(self,
|
72 |
+
bn_channels,
|
73 |
+
out_channels,
|
74 |
+
kernel_size,
|
75 |
+
stride,
|
76 |
+
padding,
|
77 |
+
dilation,
|
78 |
+
bias,
|
79 |
+
reduction=2):
|
80 |
+
super(CAMLayer, self).__init__()
|
81 |
+
self.linear_local = nn.Conv1d(bn_channels,
|
82 |
+
out_channels,
|
83 |
+
kernel_size,
|
84 |
+
stride=stride,
|
85 |
+
padding=padding,
|
86 |
+
dilation=dilation,
|
87 |
+
bias=bias)
|
88 |
+
self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
|
89 |
+
self.relu = nn.ReLU(inplace=True)
|
90 |
+
self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
|
91 |
+
self.sigmoid = nn.Sigmoid()
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
y = self.linear_local(x)
|
95 |
+
context = x.mean(-1, keepdim=True)+self.seg_pooling(x)
|
96 |
+
context = self.relu(self.linear1(context))
|
97 |
+
m = self.sigmoid(self.linear2(context))
|
98 |
+
return y*m
|
99 |
+
|
100 |
+
def seg_pooling(self, x, seg_len=100, stype='avg'):
|
101 |
+
if stype == 'avg':
|
102 |
+
seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
|
103 |
+
elif stype == 'max':
|
104 |
+
seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
|
105 |
+
else:
|
106 |
+
raise ValueError('Wrong segment pooling type.')
|
107 |
+
shape = seg.shape
|
108 |
+
seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
|
109 |
+
seg = seg[..., :x.shape[-1]]
|
110 |
+
return seg
|
111 |
+
|
112 |
+
|
113 |
+
class CAMDenseTDNNLayer(nn.Module):
|
114 |
+
def __init__(self,
|
115 |
+
in_channels,
|
116 |
+
out_channels,
|
117 |
+
bn_channels,
|
118 |
+
kernel_size,
|
119 |
+
stride=1,
|
120 |
+
dilation=1,
|
121 |
+
bias=False,
|
122 |
+
config_str='batchnorm-relu',
|
123 |
+
memory_efficient=False):
|
124 |
+
super(CAMDenseTDNNLayer, self).__init__()
|
125 |
+
assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
|
126 |
+
kernel_size)
|
127 |
+
padding = (kernel_size - 1) // 2 * dilation
|
128 |
+
self.memory_efficient = memory_efficient
|
129 |
+
self.nonlinear1 = get_nonlinear(config_str, in_channels)
|
130 |
+
self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
|
131 |
+
self.nonlinear2 = get_nonlinear(config_str, bn_channels)
|
132 |
+
self.cam_layer = CAMLayer(bn_channels,
|
133 |
+
out_channels,
|
134 |
+
kernel_size,
|
135 |
+
stride=stride,
|
136 |
+
padding=padding,
|
137 |
+
dilation=dilation,
|
138 |
+
bias=bias)
|
139 |
+
|
140 |
+
def bn_function(self, x):
|
141 |
+
return self.linear1(self.nonlinear1(x))
|
142 |
+
|
143 |
+
def forward(self, x):
|
144 |
+
if self.training and self.memory_efficient:
|
145 |
+
x = cp.checkpoint(self.bn_function, x)
|
146 |
+
else:
|
147 |
+
x = self.bn_function(x)
|
148 |
+
x = self.cam_layer(self.nonlinear2(x))
|
149 |
+
return x
|
150 |
+
|
151 |
+
|
152 |
+
class CAMDenseTDNNBlock(nn.ModuleList):
|
153 |
+
def __init__(self,
|
154 |
+
num_layers,
|
155 |
+
in_channels,
|
156 |
+
out_channels,
|
157 |
+
bn_channels,
|
158 |
+
kernel_size,
|
159 |
+
stride=1,
|
160 |
+
dilation=1,
|
161 |
+
bias=False,
|
162 |
+
config_str='batchnorm-relu',
|
163 |
+
memory_efficient=False):
|
164 |
+
super(CAMDenseTDNNBlock, self).__init__()
|
165 |
+
for i in range(num_layers):
|
166 |
+
layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
|
167 |
+
out_channels=out_channels,
|
168 |
+
bn_channels=bn_channels,
|
169 |
+
kernel_size=kernel_size,
|
170 |
+
stride=stride,
|
171 |
+
dilation=dilation,
|
172 |
+
bias=bias,
|
173 |
+
config_str=config_str,
|
174 |
+
memory_efficient=memory_efficient)
|
175 |
+
self.add_module('tdnnd%d' % (i + 1), layer)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
for layer in self:
|
179 |
+
x = torch.cat([x, layer(x)], dim=1)
|
180 |
+
return x
|
181 |
+
|
182 |
+
|
183 |
+
class TransitLayer(nn.Module):
|
184 |
+
def __init__(self,
|
185 |
+
in_channels,
|
186 |
+
out_channels,
|
187 |
+
bias=True,
|
188 |
+
config_str='batchnorm-relu'):
|
189 |
+
super(TransitLayer, self).__init__()
|
190 |
+
self.nonlinear = get_nonlinear(config_str, in_channels)
|
191 |
+
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
x = self.nonlinear(x)
|
195 |
+
x = self.linear(x)
|
196 |
+
return x
|
197 |
+
|
198 |
+
|
199 |
+
class DenseLayer(nn.Module):
|
200 |
+
def __init__(self,
|
201 |
+
in_channels,
|
202 |
+
out_channels,
|
203 |
+
bias=False,
|
204 |
+
config_str='batchnorm-relu'):
|
205 |
+
super(DenseLayer, self).__init__()
|
206 |
+
self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
|
207 |
+
self.nonlinear = get_nonlinear(config_str, out_channels)
|
208 |
+
|
209 |
+
def forward(self, x):
|
210 |
+
if len(x.shape) == 2:
|
211 |
+
x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
|
212 |
+
else:
|
213 |
+
x = self.linear(x)
|
214 |
+
x = self.nonlinear(x)
|
215 |
+
return x
|
216 |
+
|
217 |
+
|
218 |
+
class BasicResBlock(nn.Module):
|
219 |
+
expansion = 1
|
220 |
+
|
221 |
+
def __init__(self, in_planes, planes, stride=1):
|
222 |
+
super(BasicResBlock, self).__init__()
|
223 |
+
self.conv1 = nn.Conv2d(in_planes,
|
224 |
+
planes,
|
225 |
+
kernel_size=3,
|
226 |
+
stride=(stride, 1),
|
227 |
+
padding=1,
|
228 |
+
bias=False)
|
229 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
230 |
+
self.conv2 = nn.Conv2d(planes,
|
231 |
+
planes,
|
232 |
+
kernel_size=3,
|
233 |
+
stride=1,
|
234 |
+
padding=1,
|
235 |
+
bias=False)
|
236 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
237 |
+
|
238 |
+
self.shortcut = nn.Sequential()
|
239 |
+
if stride != 1 or in_planes != self.expansion * planes:
|
240 |
+
self.shortcut = nn.Sequential(
|
241 |
+
nn.Conv2d(in_planes,
|
242 |
+
self.expansion * planes,
|
243 |
+
kernel_size=1,
|
244 |
+
stride=(stride, 1),
|
245 |
+
bias=False),
|
246 |
+
nn.BatchNorm2d(self.expansion * planes))
|
247 |
+
|
248 |
+
def forward(self, x):
|
249 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
250 |
+
out = self.bn2(self.conv2(out))
|
251 |
+
out += self.shortcut(x)
|
252 |
+
out = F.relu(out)
|
253 |
+
return out
|
modules/commons.py
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from munch import Munch
|
7 |
+
import json
|
8 |
+
|
9 |
+
|
10 |
+
class AttrDict(dict):
|
11 |
+
def __init__(self, *args, **kwargs):
|
12 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
13 |
+
self.__dict__ = self
|
14 |
+
|
15 |
+
|
16 |
+
def init_weights(m, mean=0.0, std=0.01):
|
17 |
+
classname = m.__class__.__name__
|
18 |
+
if classname.find("Conv") != -1:
|
19 |
+
m.weight.data.normal_(mean, std)
|
20 |
+
|
21 |
+
|
22 |
+
def get_padding(kernel_size, dilation=1):
|
23 |
+
return int((kernel_size * dilation - dilation) / 2)
|
24 |
+
|
25 |
+
|
26 |
+
def convert_pad_shape(pad_shape):
|
27 |
+
l = pad_shape[::-1]
|
28 |
+
pad_shape = [item for sublist in l for item in sublist]
|
29 |
+
return pad_shape
|
30 |
+
|
31 |
+
|
32 |
+
def intersperse(lst, item):
|
33 |
+
result = [item] * (len(lst) * 2 + 1)
|
34 |
+
result[1::2] = lst
|
35 |
+
return result
|
36 |
+
|
37 |
+
|
38 |
+
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
39 |
+
"""KL(P||Q)"""
|
40 |
+
kl = (logs_q - logs_p) - 0.5
|
41 |
+
kl += (
|
42 |
+
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
43 |
+
)
|
44 |
+
return kl
|
45 |
+
|
46 |
+
|
47 |
+
def rand_gumbel(shape):
|
48 |
+
"""Sample from the Gumbel distribution, protect from overflows."""
|
49 |
+
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
50 |
+
return -torch.log(-torch.log(uniform_samples))
|
51 |
+
|
52 |
+
|
53 |
+
def rand_gumbel_like(x):
|
54 |
+
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
55 |
+
return g
|
56 |
+
|
57 |
+
|
58 |
+
def slice_segments(x, ids_str, segment_size=4):
|
59 |
+
ret = torch.zeros_like(x[:, :, :segment_size])
|
60 |
+
for i in range(x.size(0)):
|
61 |
+
idx_str = ids_str[i]
|
62 |
+
idx_end = idx_str + segment_size
|
63 |
+
ret[i] = x[i, :, idx_str:idx_end]
|
64 |
+
return ret
|
65 |
+
|
66 |
+
|
67 |
+
def slice_segments_audio(x, ids_str, segment_size=4):
|
68 |
+
ret = torch.zeros_like(x[:, :segment_size])
|
69 |
+
for i in range(x.size(0)):
|
70 |
+
idx_str = ids_str[i]
|
71 |
+
idx_end = idx_str + segment_size
|
72 |
+
ret[i] = x[i, idx_str:idx_end]
|
73 |
+
return ret
|
74 |
+
|
75 |
+
|
76 |
+
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
77 |
+
b, d, t = x.size()
|
78 |
+
if x_lengths is None:
|
79 |
+
x_lengths = t
|
80 |
+
ids_str_max = x_lengths - segment_size + 1
|
81 |
+
ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
|
82 |
+
dtype=torch.long
|
83 |
+
)
|
84 |
+
ret = slice_segments(x, ids_str, segment_size)
|
85 |
+
return ret, ids_str
|
86 |
+
|
87 |
+
|
88 |
+
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
89 |
+
position = torch.arange(length, dtype=torch.float)
|
90 |
+
num_timescales = channels // 2
|
91 |
+
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
|
92 |
+
num_timescales - 1
|
93 |
+
)
|
94 |
+
inv_timescales = min_timescale * torch.exp(
|
95 |
+
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
96 |
+
)
|
97 |
+
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
98 |
+
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
99 |
+
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
100 |
+
signal = signal.view(1, channels, length)
|
101 |
+
return signal
|
102 |
+
|
103 |
+
|
104 |
+
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
|
105 |
+
b, channels, length = x.size()
|
106 |
+
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
107 |
+
return x + signal.to(dtype=x.dtype, device=x.device)
|
108 |
+
|
109 |
+
|
110 |
+
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
|
111 |
+
b, channels, length = x.size()
|
112 |
+
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
113 |
+
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
114 |
+
|
115 |
+
|
116 |
+
def subsequent_mask(length):
|
117 |
+
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
118 |
+
return mask
|
119 |
+
|
120 |
+
|
121 |
+
@torch.jit.script
|
122 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
123 |
+
n_channels_int = n_channels[0]
|
124 |
+
in_act = input_a + input_b
|
125 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
126 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
127 |
+
acts = t_act * s_act
|
128 |
+
return acts
|
129 |
+
|
130 |
+
|
131 |
+
def convert_pad_shape(pad_shape):
|
132 |
+
l = pad_shape[::-1]
|
133 |
+
pad_shape = [item for sublist in l for item in sublist]
|
134 |
+
return pad_shape
|
135 |
+
|
136 |
+
|
137 |
+
def shift_1d(x):
|
138 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
139 |
+
return x
|
140 |
+
|
141 |
+
|
142 |
+
def sequence_mask(length, max_length=None):
|
143 |
+
if max_length is None:
|
144 |
+
max_length = length.max()
|
145 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
146 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
147 |
+
|
148 |
+
|
149 |
+
def avg_with_mask(x, mask):
|
150 |
+
assert mask.dtype == torch.float, "Mask should be float"
|
151 |
+
|
152 |
+
if mask.ndim == 2:
|
153 |
+
mask = mask.unsqueeze(1)
|
154 |
+
|
155 |
+
if mask.shape[1] == 1:
|
156 |
+
mask = mask.expand_as(x)
|
157 |
+
|
158 |
+
return (x * mask).sum() / mask.sum()
|
159 |
+
|
160 |
+
|
161 |
+
def generate_path(duration, mask):
|
162 |
+
"""
|
163 |
+
duration: [b, 1, t_x]
|
164 |
+
mask: [b, 1, t_y, t_x]
|
165 |
+
"""
|
166 |
+
device = duration.device
|
167 |
+
|
168 |
+
b, _, t_y, t_x = mask.shape
|
169 |
+
cum_duration = torch.cumsum(duration, -1)
|
170 |
+
|
171 |
+
cum_duration_flat = cum_duration.view(b * t_x)
|
172 |
+
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
173 |
+
path = path.view(b, t_x, t_y)
|
174 |
+
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
175 |
+
path = path.unsqueeze(1).transpose(2, 3) * mask
|
176 |
+
return path
|
177 |
+
|
178 |
+
|
179 |
+
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
180 |
+
if isinstance(parameters, torch.Tensor):
|
181 |
+
parameters = [parameters]
|
182 |
+
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
183 |
+
norm_type = float(norm_type)
|
184 |
+
if clip_value is not None:
|
185 |
+
clip_value = float(clip_value)
|
186 |
+
|
187 |
+
total_norm = 0
|
188 |
+
for p in parameters:
|
189 |
+
param_norm = p.grad.data.norm(norm_type)
|
190 |
+
total_norm += param_norm.item() ** norm_type
|
191 |
+
if clip_value is not None:
|
192 |
+
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
193 |
+
total_norm = total_norm ** (1.0 / norm_type)
|
194 |
+
return total_norm
|
195 |
+
|
196 |
+
|
197 |
+
def log_norm(x, mean=-4, std=4, dim=2):
|
198 |
+
"""
|
199 |
+
normalized log mel -> mel -> norm -> log(norm)
|
200 |
+
"""
|
201 |
+
x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
|
202 |
+
return x
|
203 |
+
|
204 |
+
|
205 |
+
def load_F0_models(path):
|
206 |
+
# load F0 model
|
207 |
+
from .JDC.model import JDCNet
|
208 |
+
|
209 |
+
F0_model = JDCNet(num_class=1, seq_len=192)
|
210 |
+
params = torch.load(path, map_location="cpu")["net"]
|
211 |
+
F0_model.load_state_dict(params)
|
212 |
+
_ = F0_model.train()
|
213 |
+
|
214 |
+
return F0_model
|
215 |
+
|
216 |
+
|
217 |
+
def modify_w2v_forward(self, output_layer=15):
|
218 |
+
"""
|
219 |
+
change forward method of w2v encoder to get its intermediate layer output
|
220 |
+
:param self:
|
221 |
+
:param layer:
|
222 |
+
:return:
|
223 |
+
"""
|
224 |
+
from transformers.modeling_outputs import BaseModelOutput
|
225 |
+
|
226 |
+
def forward(
|
227 |
+
hidden_states,
|
228 |
+
attention_mask=None,
|
229 |
+
output_attentions=False,
|
230 |
+
output_hidden_states=False,
|
231 |
+
return_dict=True,
|
232 |
+
):
|
233 |
+
all_hidden_states = () if output_hidden_states else None
|
234 |
+
all_self_attentions = () if output_attentions else None
|
235 |
+
|
236 |
+
conv_attention_mask = attention_mask
|
237 |
+
if attention_mask is not None:
|
238 |
+
# make sure padded tokens output 0
|
239 |
+
hidden_states = hidden_states.masked_fill(
|
240 |
+
~attention_mask.bool().unsqueeze(-1), 0.0
|
241 |
+
)
|
242 |
+
|
243 |
+
# extend attention_mask
|
244 |
+
attention_mask = 1.0 - attention_mask[:, None, None, :].to(
|
245 |
+
dtype=hidden_states.dtype
|
246 |
+
)
|
247 |
+
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
|
248 |
+
attention_mask = attention_mask.expand(
|
249 |
+
attention_mask.shape[0],
|
250 |
+
1,
|
251 |
+
attention_mask.shape[-1],
|
252 |
+
attention_mask.shape[-1],
|
253 |
+
)
|
254 |
+
|
255 |
+
hidden_states = self.dropout(hidden_states)
|
256 |
+
|
257 |
+
if self.embed_positions is not None:
|
258 |
+
relative_position_embeddings = self.embed_positions(hidden_states)
|
259 |
+
else:
|
260 |
+
relative_position_embeddings = None
|
261 |
+
|
262 |
+
deepspeed_zero3_is_enabled = False
|
263 |
+
|
264 |
+
for i, layer in enumerate(self.layers):
|
265 |
+
if output_hidden_states:
|
266 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
267 |
+
|
268 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
269 |
+
dropout_probability = torch.rand([])
|
270 |
+
|
271 |
+
skip_the_layer = (
|
272 |
+
True
|
273 |
+
if self.training and (dropout_probability < self.config.layerdrop)
|
274 |
+
else False
|
275 |
+
)
|
276 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
277 |
+
# under deepspeed zero3 all gpus must run in sync
|
278 |
+
if self.gradient_checkpointing and self.training:
|
279 |
+
layer_outputs = self._gradient_checkpointing_func(
|
280 |
+
layer.__call__,
|
281 |
+
hidden_states,
|
282 |
+
attention_mask,
|
283 |
+
relative_position_embeddings,
|
284 |
+
output_attentions,
|
285 |
+
conv_attention_mask,
|
286 |
+
)
|
287 |
+
else:
|
288 |
+
layer_outputs = layer(
|
289 |
+
hidden_states,
|
290 |
+
attention_mask=attention_mask,
|
291 |
+
relative_position_embeddings=relative_position_embeddings,
|
292 |
+
output_attentions=output_attentions,
|
293 |
+
conv_attention_mask=conv_attention_mask,
|
294 |
+
)
|
295 |
+
hidden_states = layer_outputs[0]
|
296 |
+
|
297 |
+
if skip_the_layer:
|
298 |
+
layer_outputs = (None, None)
|
299 |
+
|
300 |
+
if output_attentions:
|
301 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
302 |
+
|
303 |
+
if i == output_layer - 1:
|
304 |
+
break
|
305 |
+
|
306 |
+
if output_hidden_states:
|
307 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
308 |
+
|
309 |
+
if not return_dict:
|
310 |
+
return tuple(
|
311 |
+
v
|
312 |
+
for v in [hidden_states, all_hidden_states, all_self_attentions]
|
313 |
+
if v is not None
|
314 |
+
)
|
315 |
+
return BaseModelOutput(
|
316 |
+
last_hidden_state=hidden_states,
|
317 |
+
hidden_states=all_hidden_states,
|
318 |
+
attentions=all_self_attentions,
|
319 |
+
)
|
320 |
+
|
321 |
+
return forward
|
322 |
+
|
323 |
+
|
324 |
+
MATPLOTLIB_FLAG = False
|
325 |
+
|
326 |
+
|
327 |
+
def plot_spectrogram_to_numpy(spectrogram):
|
328 |
+
global MATPLOTLIB_FLAG
|
329 |
+
if not MATPLOTLIB_FLAG:
|
330 |
+
import matplotlib
|
331 |
+
import logging
|
332 |
+
|
333 |
+
matplotlib.use("Agg")
|
334 |
+
MATPLOTLIB_FLAG = True
|
335 |
+
mpl_logger = logging.getLogger("matplotlib")
|
336 |
+
mpl_logger.setLevel(logging.WARNING)
|
337 |
+
import matplotlib.pylab as plt
|
338 |
+
import numpy as np
|
339 |
+
|
340 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
341 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
342 |
+
plt.colorbar(im, ax=ax)
|
343 |
+
plt.xlabel("Frames")
|
344 |
+
plt.ylabel("Channels")
|
345 |
+
plt.tight_layout()
|
346 |
+
|
347 |
+
fig.canvas.draw()
|
348 |
+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
349 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
350 |
+
plt.close()
|
351 |
+
return data
|
352 |
+
|
353 |
+
|
354 |
+
def normalize_f0(f0_sequence):
|
355 |
+
# Remove unvoiced frames (replace with -1)
|
356 |
+
voiced_indices = np.where(f0_sequence > 0)[0]
|
357 |
+
f0_voiced = f0_sequence[voiced_indices]
|
358 |
+
|
359 |
+
# Convert to log scale
|
360 |
+
log_f0 = np.log2(f0_voiced)
|
361 |
+
|
362 |
+
# Calculate mean and standard deviation
|
363 |
+
mean_f0 = np.mean(log_f0)
|
364 |
+
std_f0 = np.std(log_f0)
|
365 |
+
|
366 |
+
# Normalize the F0 sequence
|
367 |
+
normalized_f0 = (log_f0 - mean_f0) / std_f0
|
368 |
+
|
369 |
+
# Create the normalized F0 sequence with unvoiced frames
|
370 |
+
normalized_sequence = np.zeros_like(f0_sequence)
|
371 |
+
normalized_sequence[voiced_indices] = normalized_f0
|
372 |
+
normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames
|
373 |
+
|
374 |
+
return normalized_sequence
|
375 |
+
|
376 |
+
|
377 |
+
def build_model(args, stage="DiT"):
|
378 |
+
if stage == "DiT":
|
379 |
+
from modules.flow_matching import CFM
|
380 |
+
from modules.length_regulator import InterpolateRegulator
|
381 |
+
|
382 |
+
length_regulator = InterpolateRegulator(
|
383 |
+
channels=args.length_regulator.channels,
|
384 |
+
sampling_ratios=args.length_regulator.sampling_ratios,
|
385 |
+
is_discrete=args.length_regulator.is_discrete,
|
386 |
+
codebook_size=args.length_regulator.content_codebook_size,
|
387 |
+
)
|
388 |
+
cfm = CFM(args)
|
389 |
+
nets = Munch(
|
390 |
+
cfm=cfm,
|
391 |
+
length_regulator=length_regulator,
|
392 |
+
)
|
393 |
+
else:
|
394 |
+
raise ValueError(f"Unknown stage: {stage}")
|
395 |
+
|
396 |
+
return nets
|
397 |
+
|
398 |
+
|
399 |
+
def load_checkpoint(
|
400 |
+
model,
|
401 |
+
optimizer,
|
402 |
+
path,
|
403 |
+
load_only_params=True,
|
404 |
+
ignore_modules=[],
|
405 |
+
is_distributed=False,
|
406 |
+
):
|
407 |
+
state = torch.load(path, map_location="cpu")
|
408 |
+
params = state["net"]
|
409 |
+
for key in model:
|
410 |
+
if key in params and key not in ignore_modules:
|
411 |
+
if not is_distributed:
|
412 |
+
# strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
|
413 |
+
for k in list(params[key].keys()):
|
414 |
+
if k.startswith("module."):
|
415 |
+
params[key][k[len("module.") :]] = params[key][k]
|
416 |
+
del params[key][k]
|
417 |
+
model_state_dict = model[key].state_dict()
|
418 |
+
# 过滤出形状匹配的键值对
|
419 |
+
filtered_state_dict = {
|
420 |
+
k: v
|
421 |
+
for k, v in params[key].items()
|
422 |
+
if k in model_state_dict and v.shape == model_state_dict[k].shape
|
423 |
+
}
|
424 |
+
skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
|
425 |
+
if skipped_keys:
|
426 |
+
print(
|
427 |
+
f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
|
428 |
+
)
|
429 |
+
print("%s loaded" % key)
|
430 |
+
model[key].load_state_dict(filtered_state_dict, strict=False)
|
431 |
+
_ = [model[key].eval() for key in model]
|
432 |
+
|
433 |
+
if not load_only_params:
|
434 |
+
epoch = state["epoch"] + 1
|
435 |
+
iters = state["iters"]
|
436 |
+
optimizer.load_state_dict(state["optimizer"])
|
437 |
+
optimizer.load_scheduler_state_dict(state["scheduler"])
|
438 |
+
|
439 |
+
else:
|
440 |
+
epoch = 0
|
441 |
+
iters = 0
|
442 |
+
|
443 |
+
return model, optimizer, epoch, iters
|
444 |
+
|
445 |
+
|
446 |
+
def recursive_munch(d):
|
447 |
+
if isinstance(d, dict):
|
448 |
+
return Munch((k, recursive_munch(v)) for k, v in d.items())
|
449 |
+
elif isinstance(d, list):
|
450 |
+
return [recursive_munch(v) for v in d]
|
451 |
+
else:
|
452 |
+
return d
|
modules/cosyvoice_tokenizer/__pycache__/frontend.cpython-310.pyc
ADDED
Binary file (2.56 kB). View file
|
|
modules/cosyvoice_tokenizer/frontend.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from functools import partial
|
15 |
+
import onnxruntime
|
16 |
+
import torch
|
17 |
+
import numpy as np
|
18 |
+
import whisper
|
19 |
+
import torchaudio.compliance.kaldi as kaldi
|
20 |
+
|
21 |
+
class CosyVoiceFrontEnd:
|
22 |
+
|
23 |
+
def __init__(self, speech_tokenizer_model: str, device: str = 'cuda', device_id: int = 0):
|
24 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
25 |
+
option = onnxruntime.SessionOptions()
|
26 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
27 |
+
option.intra_op_num_threads = 1
|
28 |
+
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if device == "cuda" else "CPUExecutionProvider"])
|
29 |
+
if device == 'cuda':
|
30 |
+
self.speech_tokenizer_session.set_providers(['CUDAExecutionProvider'], [ {'device_id': device_id}])
|
31 |
+
|
32 |
+
def extract_speech_token(self, speech):
|
33 |
+
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
34 |
+
speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
|
35 |
+
self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
36 |
+
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
37 |
+
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
38 |
+
return speech_token, speech_token_len
|
39 |
+
|
40 |
+
def _extract_spk_embedding(self, speech):
|
41 |
+
feat = kaldi.fbank(speech,
|
42 |
+
num_mel_bins=80,
|
43 |
+
dither=0,
|
44 |
+
sample_frequency=16000)
|
45 |
+
feat = feat - feat.mean(dim=0, keepdim=True)
|
46 |
+
embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
47 |
+
embedding = torch.tensor([embedding]).to(self.device)
|
48 |
+
return embedding
|
49 |
+
|
50 |
+
def _extract_speech_feat(self, speech):
|
51 |
+
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
52 |
+
speech_feat = speech_feat.unsqueeze(dim=0)
|
53 |
+
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
54 |
+
return speech_feat, speech_feat_len
|
modules/diffusion_transformer.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import math
|
4 |
+
|
5 |
+
from modules.gpt_fast.model import ModelArgs, Transformer
|
6 |
+
from modules.wavenet import WN
|
7 |
+
from modules.commons import sequence_mask
|
8 |
+
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
def modulate(x, shift, scale):
|
12 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
13 |
+
|
14 |
+
|
15 |
+
#################################################################################
|
16 |
+
# Embedding Layers for Timesteps and Class Labels #
|
17 |
+
#################################################################################
|
18 |
+
|
19 |
+
class TimestepEmbedder(nn.Module):
|
20 |
+
"""
|
21 |
+
Embeds scalar timesteps into vector representations.
|
22 |
+
"""
|
23 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
24 |
+
super().__init__()
|
25 |
+
self.mlp = nn.Sequential(
|
26 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
27 |
+
nn.SiLU(),
|
28 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
29 |
+
)
|
30 |
+
self.frequency_embedding_size = frequency_embedding_size
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def timestep_embedding(t, dim, max_period=10000, scale=1000):
|
34 |
+
"""
|
35 |
+
Create sinusoidal timestep embeddings.
|
36 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
37 |
+
These may be fractional.
|
38 |
+
:param dim: the dimension of the output.
|
39 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
40 |
+
:return: an (N, D) Tensor of positional embeddings.
|
41 |
+
"""
|
42 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
43 |
+
half = dim // 2
|
44 |
+
freqs = torch.exp(
|
45 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
46 |
+
).to(device=t.device)
|
47 |
+
args = scale * t[:, None].float() * freqs[None]
|
48 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
49 |
+
if dim % 2:
|
50 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
51 |
+
return embedding
|
52 |
+
|
53 |
+
def forward(self, t):
|
54 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
55 |
+
t_emb = self.mlp(t_freq)
|
56 |
+
return t_emb
|
57 |
+
|
58 |
+
|
59 |
+
class StyleEmbedder(nn.Module):
|
60 |
+
"""
|
61 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
62 |
+
"""
|
63 |
+
def __init__(self, input_size, hidden_size, dropout_prob):
|
64 |
+
super().__init__()
|
65 |
+
use_cfg_embedding = dropout_prob > 0
|
66 |
+
self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
|
67 |
+
self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
|
68 |
+
self.input_size = input_size
|
69 |
+
self.dropout_prob = dropout_prob
|
70 |
+
|
71 |
+
def forward(self, labels, train, force_drop_ids=None):
|
72 |
+
use_dropout = self.dropout_prob > 0
|
73 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
74 |
+
labels = self.token_drop(labels, force_drop_ids)
|
75 |
+
else:
|
76 |
+
labels = self.style_in(labels)
|
77 |
+
embeddings = labels
|
78 |
+
return embeddings
|
79 |
+
|
80 |
+
class FinalLayer(nn.Module):
|
81 |
+
"""
|
82 |
+
The final layer of DiT.
|
83 |
+
"""
|
84 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
85 |
+
super().__init__()
|
86 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
87 |
+
self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
|
88 |
+
self.adaLN_modulation = nn.Sequential(
|
89 |
+
nn.SiLU(),
|
90 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
91 |
+
)
|
92 |
+
|
93 |
+
def forward(self, x, c):
|
94 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
95 |
+
x = modulate(self.norm_final(x), shift, scale)
|
96 |
+
x = self.linear(x)
|
97 |
+
return x
|
98 |
+
|
99 |
+
class DiT(torch.nn.Module):
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
args
|
103 |
+
):
|
104 |
+
super(DiT, self).__init__()
|
105 |
+
self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
|
106 |
+
self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
|
107 |
+
self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
|
108 |
+
model_args = ModelArgs(
|
109 |
+
block_size=args.DiT.block_size,
|
110 |
+
n_layer=args.DiT.depth,
|
111 |
+
n_head=args.DiT.num_heads,
|
112 |
+
dim=args.DiT.hidden_dim,
|
113 |
+
head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
|
114 |
+
vocab_size=1024,
|
115 |
+
uvit_skip_connection=self.uvit_skip_connection,
|
116 |
+
)
|
117 |
+
self.transformer = Transformer(model_args)
|
118 |
+
self.in_channels = args.DiT.in_channels
|
119 |
+
self.out_channels = args.DiT.in_channels
|
120 |
+
self.num_heads = args.DiT.num_heads
|
121 |
+
|
122 |
+
self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
|
123 |
+
|
124 |
+
self.content_type = args.DiT.content_type # 'discrete' or 'continuous'
|
125 |
+
self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
|
126 |
+
self.content_dim = args.DiT.content_dim # for continuous content
|
127 |
+
self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content
|
128 |
+
self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
|
129 |
+
|
130 |
+
self.is_causal = args.DiT.is_causal
|
131 |
+
|
132 |
+
self.n_f0_bins = args.DiT.n_f0_bins
|
133 |
+
self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
|
134 |
+
self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
|
135 |
+
self.f0_condition = args.DiT.f0_condition
|
136 |
+
|
137 |
+
self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
|
138 |
+
self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
|
139 |
+
# self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
|
140 |
+
# self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
|
141 |
+
|
142 |
+
input_pos = torch.arange(args.DiT.block_size)
|
143 |
+
self.register_buffer("input_pos", input_pos)
|
144 |
+
|
145 |
+
self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
|
146 |
+
self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
|
147 |
+
self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet
|
148 |
+
if self.final_layer_type == 'wavenet':
|
149 |
+
self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
|
150 |
+
kernel_size=args.wavenet.kernel_size,
|
151 |
+
dilation_rate=args.wavenet.dilation_rate,
|
152 |
+
n_layers=args.wavenet.num_layers,
|
153 |
+
gin_channels=args.wavenet.hidden_dim,
|
154 |
+
p_dropout=args.wavenet.p_dropout,
|
155 |
+
causal=False)
|
156 |
+
self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
|
157 |
+
else:
|
158 |
+
self.final_mlp = nn.Sequential(
|
159 |
+
nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
|
160 |
+
nn.SiLU(),
|
161 |
+
nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
|
162 |
+
)
|
163 |
+
self.final_conv = nn.Conv1d(args.DiT.in_channels, args.DiT.in_channels, kernel_size=3, padding=1)
|
164 |
+
self.transformer_style_condition = args.DiT.style_condition
|
165 |
+
self.wavenet_style_condition = args.wavenet.style_condition
|
166 |
+
assert args.DiT.style_condition == args.wavenet.style_condition
|
167 |
+
|
168 |
+
self.class_dropout_prob = args.DiT.class_dropout_prob
|
169 |
+
self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
|
170 |
+
self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) # residual connection from tranformer output to final output
|
171 |
+
self.long_skip_connection = args.DiT.long_skip_connection
|
172 |
+
self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
|
173 |
+
|
174 |
+
self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
|
175 |
+
args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
|
176 |
+
args.DiT.hidden_dim)
|
177 |
+
if self.style_as_token:
|
178 |
+
self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
|
179 |
+
|
180 |
+
def setup_caches(self, max_batch_size, max_seq_length):
|
181 |
+
self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
|
182 |
+
def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
|
183 |
+
class_dropout = False
|
184 |
+
if self.training and torch.rand(1) < self.class_dropout_prob:
|
185 |
+
class_dropout = True
|
186 |
+
if not self.training and mask_content:
|
187 |
+
class_dropout = True
|
188 |
+
# cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
|
189 |
+
cond_in_module = self.cond_projection
|
190 |
+
|
191 |
+
B, _, T = x.size()
|
192 |
+
|
193 |
+
|
194 |
+
t1 = self.t_embedder(t) # (N, D)
|
195 |
+
|
196 |
+
cond = cond_in_module(cond)
|
197 |
+
if self.f0_condition and f0 is not None:
|
198 |
+
quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
|
199 |
+
cond = cond + self.f0_embedder(quantized_f0)
|
200 |
+
|
201 |
+
x = x.transpose(1, 2)
|
202 |
+
prompt_x = prompt_x.transpose(1, 2)
|
203 |
+
|
204 |
+
x_in = torch.cat([x, prompt_x, cond], dim=-1)
|
205 |
+
if self.transformer_style_condition and not self.style_as_token:
|
206 |
+
x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
|
207 |
+
if class_dropout:
|
208 |
+
x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
|
209 |
+
x_in = self.cond_x_merge_linear(x_in) # (N, T, D)
|
210 |
+
|
211 |
+
if self.style_as_token:
|
212 |
+
style = self.style_in(style)
|
213 |
+
style = torch.zeros_like(style) if class_dropout else style
|
214 |
+
x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
|
215 |
+
if self.time_as_token:
|
216 |
+
x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
|
217 |
+
x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
|
218 |
+
input_pos = self.input_pos[:x_in.size(1)] # (T,)
|
219 |
+
x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
|
220 |
+
x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
|
221 |
+
x_res = x_res[:, 1:] if self.time_as_token else x_res
|
222 |
+
x_res = x_res[:, 1:] if self.style_as_token else x_res
|
223 |
+
if self.long_skip_connection:
|
224 |
+
x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
|
225 |
+
if self.final_layer_type == 'wavenet':
|
226 |
+
x = self.conv1(x_res)
|
227 |
+
x = x.transpose(1, 2)
|
228 |
+
t2 = self.t_embedder2(t)
|
229 |
+
x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
|
230 |
+
x_res) # long residual connection
|
231 |
+
x = self.final_layer(x, t1).transpose(1, 2)
|
232 |
+
x = self.conv2(x)
|
233 |
+
else:
|
234 |
+
x = self.final_mlp(x_res)
|
235 |
+
x = x.transpose(1, 2)
|
236 |
+
x = self.final_conv(x)
|
237 |
+
return x
|
modules/encodec.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Convolutional layers wrappers and utilities."""
|
8 |
+
|
9 |
+
import math
|
10 |
+
import typing as tp
|
11 |
+
import warnings
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
from torch.nn.utils import spectral_norm, weight_norm
|
17 |
+
|
18 |
+
import typing as tp
|
19 |
+
|
20 |
+
import einops
|
21 |
+
|
22 |
+
|
23 |
+
class ConvLayerNorm(nn.LayerNorm):
|
24 |
+
"""
|
25 |
+
Convolution-friendly LayerNorm that moves channels to last dimensions
|
26 |
+
before running the normalization and moves them back to original position right after.
|
27 |
+
"""
|
28 |
+
def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
|
29 |
+
super().__init__(normalized_shape, **kwargs)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x = einops.rearrange(x, 'b ... t -> b t ...')
|
33 |
+
x = super().forward(x)
|
34 |
+
x = einops.rearrange(x, 'b t ... -> b ... t')
|
35 |
+
return
|
36 |
+
|
37 |
+
|
38 |
+
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
39 |
+
'time_layer_norm', 'layer_norm', 'time_group_norm'])
|
40 |
+
|
41 |
+
|
42 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
|
43 |
+
assert norm in CONV_NORMALIZATIONS
|
44 |
+
if norm == 'weight_norm':
|
45 |
+
return weight_norm(module)
|
46 |
+
elif norm == 'spectral_norm':
|
47 |
+
return spectral_norm(module)
|
48 |
+
else:
|
49 |
+
# We already check was in CONV_NORMALIZATION, so any other choice
|
50 |
+
# doesn't need reparametrization.
|
51 |
+
return module
|
52 |
+
|
53 |
+
|
54 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
|
55 |
+
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
56 |
+
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
57 |
+
"""
|
58 |
+
assert norm in CONV_NORMALIZATIONS
|
59 |
+
if norm == 'layer_norm':
|
60 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
61 |
+
return ConvLayerNorm(module.out_channels, **norm_kwargs)
|
62 |
+
elif norm == 'time_group_norm':
|
63 |
+
if causal:
|
64 |
+
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
65 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
66 |
+
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
67 |
+
else:
|
68 |
+
return nn.Identity()
|
69 |
+
|
70 |
+
|
71 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
|
72 |
+
padding_total: int = 0) -> int:
|
73 |
+
"""See `pad_for_conv1d`.
|
74 |
+
"""
|
75 |
+
length = x.shape[-1]
|
76 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
77 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
78 |
+
return ideal_length - length
|
79 |
+
|
80 |
+
|
81 |
+
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
|
82 |
+
"""Pad for a convolution to make sure that the last window is full.
|
83 |
+
Extra padding is added at the end. This is required to ensure that we can rebuild
|
84 |
+
an output of the same length, as otherwise, even with padding, some time steps
|
85 |
+
might get removed.
|
86 |
+
For instance, with total padding = 4, kernel size = 4, stride = 2:
|
87 |
+
0 0 1 2 3 4 5 0 0 # (0s are padding)
|
88 |
+
1 2 3 # (output frames of a convolution, last 0 is never used)
|
89 |
+
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
|
90 |
+
1 2 3 4 # once you removed padding, we are missing one time step !
|
91 |
+
"""
|
92 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
93 |
+
return F.pad(x, (0, extra_padding))
|
94 |
+
|
95 |
+
|
96 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
|
97 |
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
98 |
+
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
99 |
+
"""
|
100 |
+
length = x.shape[-1]
|
101 |
+
padding_left, padding_right = paddings
|
102 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
103 |
+
if mode == 'reflect':
|
104 |
+
max_pad = max(padding_left, padding_right)
|
105 |
+
extra_pad = 0
|
106 |
+
if length <= max_pad:
|
107 |
+
extra_pad = max_pad - length + 1
|
108 |
+
x = F.pad(x, (0, extra_pad))
|
109 |
+
padded = F.pad(x, paddings, mode, value)
|
110 |
+
end = padded.shape[-1] - extra_pad
|
111 |
+
return padded[..., :end]
|
112 |
+
else:
|
113 |
+
return F.pad(x, paddings, mode, value)
|
114 |
+
|
115 |
+
|
116 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
117 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
118 |
+
padding_left, padding_right = paddings
|
119 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
120 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
121 |
+
end = x.shape[-1] - padding_right
|
122 |
+
return x[..., padding_left: end]
|
123 |
+
|
124 |
+
|
125 |
+
class NormConv1d(nn.Module):
|
126 |
+
"""Wrapper around Conv1d and normalization applied to this conv
|
127 |
+
to provide a uniform interface across normalization approaches.
|
128 |
+
"""
|
129 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
130 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
131 |
+
super().__init__()
|
132 |
+
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
133 |
+
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
134 |
+
self.norm_type = norm
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
x = self.conv(x)
|
138 |
+
x = self.norm(x)
|
139 |
+
return x
|
140 |
+
|
141 |
+
|
142 |
+
class NormConv2d(nn.Module):
|
143 |
+
"""Wrapper around Conv2d and normalization applied to this conv
|
144 |
+
to provide a uniform interface across normalization approaches.
|
145 |
+
"""
|
146 |
+
def __init__(self, *args, norm: str = 'none',
|
147 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
148 |
+
super().__init__()
|
149 |
+
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
150 |
+
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
151 |
+
self.norm_type = norm
|
152 |
+
|
153 |
+
def forward(self, x):
|
154 |
+
x = self.conv(x)
|
155 |
+
x = self.norm(x)
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class NormConvTranspose1d(nn.Module):
|
160 |
+
"""Wrapper around ConvTranspose1d and normalization applied to this conv
|
161 |
+
to provide a uniform interface across normalization approaches.
|
162 |
+
"""
|
163 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
164 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
165 |
+
super().__init__()
|
166 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
|
167 |
+
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
168 |
+
self.norm_type = norm
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
x = self.convtr(x)
|
172 |
+
x = self.norm(x)
|
173 |
+
return x
|
174 |
+
|
175 |
+
|
176 |
+
class NormConvTranspose2d(nn.Module):
|
177 |
+
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
178 |
+
to provide a uniform interface across normalization approaches.
|
179 |
+
"""
|
180 |
+
def __init__(self, *args, norm: str = 'none',
|
181 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
182 |
+
super().__init__()
|
183 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
|
184 |
+
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
185 |
+
|
186 |
+
def forward(self, x):
|
187 |
+
x = self.convtr(x)
|
188 |
+
x = self.norm(x)
|
189 |
+
return x
|
190 |
+
|
191 |
+
|
192 |
+
class SConv1d(nn.Module):
|
193 |
+
"""Conv1d with some builtin handling of asymmetric or causal padding
|
194 |
+
and normalization.
|
195 |
+
"""
|
196 |
+
def __init__(self, in_channels: int, out_channels: int,
|
197 |
+
kernel_size: int, stride: int = 1, dilation: int = 1,
|
198 |
+
groups: int = 1, bias: bool = True, causal: bool = False,
|
199 |
+
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
200 |
+
pad_mode: str = 'reflect', **kwargs):
|
201 |
+
super().__init__()
|
202 |
+
# warn user on unusual setup between dilation and stride
|
203 |
+
if stride > 1 and dilation > 1:
|
204 |
+
warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
|
205 |
+
f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
|
206 |
+
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
207 |
+
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
208 |
+
norm=norm, norm_kwargs=norm_kwargs)
|
209 |
+
self.causal = causal
|
210 |
+
self.pad_mode = pad_mode
|
211 |
+
|
212 |
+
def forward(self, x):
|
213 |
+
B, C, T = x.shape
|
214 |
+
kernel_size = self.conv.conv.kernel_size[0]
|
215 |
+
stride = self.conv.conv.stride[0]
|
216 |
+
dilation = self.conv.conv.dilation[0]
|
217 |
+
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
|
218 |
+
padding_total = kernel_size - stride
|
219 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
220 |
+
if self.causal:
|
221 |
+
# Left padding for causal
|
222 |
+
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
223 |
+
else:
|
224 |
+
# Asymmetric padding required for odd strides
|
225 |
+
padding_right = padding_total // 2
|
226 |
+
padding_left = padding_total - padding_right
|
227 |
+
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
228 |
+
return self.conv(x)
|
229 |
+
|
230 |
+
|
231 |
+
class SConvTranspose1d(nn.Module):
|
232 |
+
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
|
233 |
+
and normalization.
|
234 |
+
"""
|
235 |
+
def __init__(self, in_channels: int, out_channels: int,
|
236 |
+
kernel_size: int, stride: int = 1, causal: bool = False,
|
237 |
+
norm: str = 'none', trim_right_ratio: float = 1.,
|
238 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
239 |
+
super().__init__()
|
240 |
+
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
|
241 |
+
causal=causal, norm=norm, norm_kwargs=norm_kwargs)
|
242 |
+
self.causal = causal
|
243 |
+
self.trim_right_ratio = trim_right_ratio
|
244 |
+
assert self.causal or self.trim_right_ratio == 1., \
|
245 |
+
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
246 |
+
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
|
247 |
+
|
248 |
+
def forward(self, x):
|
249 |
+
kernel_size = self.convtr.convtr.kernel_size[0]
|
250 |
+
stride = self.convtr.convtr.stride[0]
|
251 |
+
padding_total = kernel_size - stride
|
252 |
+
|
253 |
+
y = self.convtr(x)
|
254 |
+
|
255 |
+
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
|
256 |
+
# removed at the very end, when keeping only the right length for the output,
|
257 |
+
# as removing it here would require also passing the length at the matching layer
|
258 |
+
# in the encoder.
|
259 |
+
if self.causal:
|
260 |
+
# Trim the padding on the right according to the specified ratio
|
261 |
+
# if trim_right_ratio = 1.0, trim everything from right
|
262 |
+
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
263 |
+
padding_left = padding_total - padding_right
|
264 |
+
y = unpad1d(y, (padding_left, padding_right))
|
265 |
+
else:
|
266 |
+
# Asymmetric padding required for odd strides
|
267 |
+
padding_right = padding_total // 2
|
268 |
+
padding_left = padding_total - padding_right
|
269 |
+
y = unpad1d(y, (padding_left, padding_right))
|
270 |
+
return y
|
271 |
+
|
272 |
+
class SLSTM(nn.Module):
|
273 |
+
"""
|
274 |
+
LSTM without worrying about the hidden state, nor the layout of the data.
|
275 |
+
Expects input as convolutional layout.
|
276 |
+
"""
|
277 |
+
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
|
278 |
+
super().__init__()
|
279 |
+
self.skip = skip
|
280 |
+
self.lstm = nn.LSTM(dimension, dimension, num_layers)
|
281 |
+
self.hidden = None
|
282 |
+
|
283 |
+
def forward(self, x):
|
284 |
+
x = x.permute(2, 0, 1)
|
285 |
+
if self.training:
|
286 |
+
y, _ = self.lstm(x)
|
287 |
+
else:
|
288 |
+
y, self.hidden = self.lstm(x, self.hidden)
|
289 |
+
if self.skip:
|
290 |
+
y = y + x
|
291 |
+
y = y.permute(1, 2, 0)
|
292 |
+
return y
|
modules/flow_matching.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from modules.diffusion_transformer import DiT
|
7 |
+
from modules.commons import sequence_mask
|
8 |
+
|
9 |
+
class BASECFM(torch.nn.Module, ABC):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
args,
|
13 |
+
):
|
14 |
+
super().__init__()
|
15 |
+
self.sigma_min = 1e-6
|
16 |
+
|
17 |
+
self.estimator = None
|
18 |
+
|
19 |
+
self.in_channels = args.DiT.in_channels
|
20 |
+
|
21 |
+
self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()
|
22 |
+
|
23 |
+
if hasattr(args.DiT, 'zero_prompt_speech_token'):
|
24 |
+
self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
|
25 |
+
else:
|
26 |
+
self.zero_prompt_speech_token = False
|
27 |
+
|
28 |
+
@torch.inference_mode()
|
29 |
+
def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
|
30 |
+
"""Forward diffusion
|
31 |
+
|
32 |
+
Args:
|
33 |
+
mu (torch.Tensor): output of encoder
|
34 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
35 |
+
mask (torch.Tensor): output_mask
|
36 |
+
shape: (batch_size, 1, mel_timesteps)
|
37 |
+
n_timesteps (int): number of diffusion steps
|
38 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
39 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
40 |
+
shape: (batch_size, spk_emb_dim)
|
41 |
+
cond: Not used but kept for future purposes
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
sample: generated mel-spectrogram
|
45 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
46 |
+
"""
|
47 |
+
B, T = mu.size(0), mu.size(1)
|
48 |
+
z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
|
49 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
50 |
+
return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
|
51 |
+
|
52 |
+
def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
|
53 |
+
"""
|
54 |
+
Fixed euler solver for ODEs.
|
55 |
+
Args:
|
56 |
+
x (torch.Tensor): random noise
|
57 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
58 |
+
shape: (n_timesteps + 1,)
|
59 |
+
mu (torch.Tensor): output of encoder
|
60 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
61 |
+
mask (torch.Tensor): output_mask
|
62 |
+
shape: (batch_size, 1, mel_timesteps)
|
63 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
64 |
+
shape: (batch_size, spk_emb_dim)
|
65 |
+
cond: Not used but kept for future purposes
|
66 |
+
"""
|
67 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
68 |
+
|
69 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
70 |
+
# Or in future might add like a return_all_steps flag
|
71 |
+
sol = []
|
72 |
+
# apply prompt
|
73 |
+
prompt_len = prompt.size(-1)
|
74 |
+
prompt_x = torch.zeros_like(x)
|
75 |
+
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
|
76 |
+
x[..., :prompt_len] = 0
|
77 |
+
if self.zero_prompt_speech_token:
|
78 |
+
mu[..., :prompt_len] = 0
|
79 |
+
for step in range(1, len(t_span)):
|
80 |
+
dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
|
81 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
82 |
+
if inference_cfg_rate > 0:
|
83 |
+
cfg_dphi_dt = self.estimator(
|
84 |
+
x, torch.zeros_like(prompt_x), x_lens, t.unsqueeze(0),
|
85 |
+
torch.zeros_like(style),
|
86 |
+
torch.zeros_like(mu), None
|
87 |
+
)
|
88 |
+
dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
|
89 |
+
inference_cfg_rate * cfg_dphi_dt)
|
90 |
+
x = x + dt * dphi_dt
|
91 |
+
t = t + dt
|
92 |
+
sol.append(x)
|
93 |
+
if step < len(t_span) - 1:
|
94 |
+
dt = t_span[step + 1] - t
|
95 |
+
x[:, :, :prompt_len] = 0
|
96 |
+
|
97 |
+
return sol[-1]
|
98 |
+
|
99 |
+
def forward(self, x1, x_lens, prompt_lens, mu, style, f0=None):
|
100 |
+
"""Computes diffusion loss
|
101 |
+
|
102 |
+
Args:
|
103 |
+
x1 (torch.Tensor): Target
|
104 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
105 |
+
mask (torch.Tensor): target mask
|
106 |
+
shape: (batch_size, 1, mel_timesteps)
|
107 |
+
mu (torch.Tensor): output of encoder
|
108 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
109 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
110 |
+
shape: (batch_size, spk_emb_dim)
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
loss: conditional flow matching loss
|
114 |
+
y: conditional flow
|
115 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
116 |
+
"""
|
117 |
+
b, _, t = x1.shape
|
118 |
+
|
119 |
+
# random timestep
|
120 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
|
121 |
+
# sample noise p(x_0)
|
122 |
+
z = torch.randn_like(x1)
|
123 |
+
|
124 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
125 |
+
u = x1 - (1 - self.sigma_min) * z
|
126 |
+
|
127 |
+
prompt = torch.zeros_like(x1)
|
128 |
+
for bib in range(b):
|
129 |
+
prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
|
130 |
+
# range covered by prompt are set to 0
|
131 |
+
y[bib, :, :prompt_lens[bib]] = 0
|
132 |
+
if self.zero_prompt_speech_token:
|
133 |
+
mu[bib, :, :prompt_lens[bib]] = 0
|
134 |
+
|
135 |
+
estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu, f0)
|
136 |
+
loss = 0
|
137 |
+
for bib in range(b):
|
138 |
+
loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
|
139 |
+
loss /= b
|
140 |
+
|
141 |
+
return loss, y
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
class CFM(BASECFM):
|
146 |
+
def __init__(self, args):
|
147 |
+
super().__init__(
|
148 |
+
args
|
149 |
+
)
|
150 |
+
if args.dit_type == "DiT":
|
151 |
+
self.estimator = DiT(args)
|
152 |
+
else:
|
153 |
+
raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
|
modules/gpt_fast/__pycache__/model.cpython-310.pyc
ADDED
Binary file (12.2 kB). View file
|
|
modules/gpt_fast/generate.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import itertools
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Optional, Tuple
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch._dynamo.config
|
14 |
+
import torch._inductor.config
|
15 |
+
|
16 |
+
def device_sync(device):
|
17 |
+
if "cuda" in device:
|
18 |
+
torch.cuda.synchronize(device)
|
19 |
+
elif ("cpu" in device) or ("mps" in device):
|
20 |
+
pass
|
21 |
+
else:
|
22 |
+
print(f"device={device} is not yet suppported")
|
23 |
+
|
24 |
+
|
25 |
+
torch._inductor.config.coordinate_descent_tuning = True
|
26 |
+
torch._inductor.config.triton.unique_kernel_names = True
|
27 |
+
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
|
28 |
+
|
29 |
+
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
30 |
+
|
31 |
+
# support running without installing as a package
|
32 |
+
wd = Path(__file__).parent.parent.resolve()
|
33 |
+
sys.path.append(str(wd))
|
34 |
+
|
35 |
+
from model import Transformer
|
36 |
+
from tokenizer import get_tokenizer
|
37 |
+
|
38 |
+
def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
|
39 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
40 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
41 |
+
|
42 |
+
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
43 |
+
logits = logits / max(temperature, 1e-5)
|
44 |
+
|
45 |
+
if top_k is not None:
|
46 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
47 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
48 |
+
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
49 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
50 |
+
return probs
|
51 |
+
|
52 |
+
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
53 |
+
probs = logits_to_probs(logits[0, -1], temperature, top_k)
|
54 |
+
idx_next = multinomial_sample_one_no_sync(probs)
|
55 |
+
return idx_next, probs
|
56 |
+
|
57 |
+
def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
|
58 |
+
# input_pos: [B, S]
|
59 |
+
logits = model(x, input_pos)
|
60 |
+
return sample(logits, **sampling_kwargs)[0]
|
61 |
+
|
62 |
+
def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
63 |
+
# input_pos: [B, 1]
|
64 |
+
assert input_pos.shape[-1] == 1
|
65 |
+
logits = model(x, input_pos)
|
66 |
+
return sample(logits, **sampling_kwargs)
|
67 |
+
|
68 |
+
def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
|
69 |
+
new_tokens, new_probs = [], []
|
70 |
+
for i in range(num_new_tokens):
|
71 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
|
72 |
+
next_token, next_prob = decode_one_token(
|
73 |
+
model, cur_token, input_pos, **sampling_kwargs
|
74 |
+
)
|
75 |
+
input_pos += 1
|
76 |
+
new_tokens.append(next_token.clone())
|
77 |
+
callback(new_tokens[-1])
|
78 |
+
new_probs.append(next_prob.clone())
|
79 |
+
cur_token = next_token.view(1, -1)
|
80 |
+
|
81 |
+
return new_tokens, new_probs
|
82 |
+
|
83 |
+
|
84 |
+
def model_forward(model, x, input_pos):
|
85 |
+
return model(x, input_pos)
|
86 |
+
|
87 |
+
def speculative_decode(
|
88 |
+
model: Transformer,
|
89 |
+
draft_model: Transformer,
|
90 |
+
cur_token: torch.Tensor,
|
91 |
+
input_pos: int,
|
92 |
+
speculate_k: int,
|
93 |
+
**sampling_kwargs
|
94 |
+
) -> torch.Tensor:
|
95 |
+
# draft model inference sequentially
|
96 |
+
device = cur_token.device
|
97 |
+
orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
|
98 |
+
draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
|
99 |
+
|
100 |
+
draft_tokens = torch.cat(draft_tokens)
|
101 |
+
# parallel inference on target model using draft tokens
|
102 |
+
target_logits = model_forward(
|
103 |
+
model,
|
104 |
+
torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
|
105 |
+
torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
|
106 |
+
)
|
107 |
+
target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
|
108 |
+
draft_probs = torch.stack(draft_probs)
|
109 |
+
# q: target prob, p: draft prob
|
110 |
+
# q >= p: always accept draft token
|
111 |
+
# q < p: q/p prob to accept draft token
|
112 |
+
p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
|
113 |
+
q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
|
114 |
+
accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
|
115 |
+
rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
|
116 |
+
|
117 |
+
if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
|
118 |
+
accept_length = speculate_k + 1
|
119 |
+
last_token = multinomial_sample_one_no_sync(target_probs[-1])
|
120 |
+
# fill last token into draft model
|
121 |
+
model_forward(
|
122 |
+
draft_model,
|
123 |
+
draft_tokens[-1].view(1, -1),
|
124 |
+
orig_input_pos + speculate_k,
|
125 |
+
)
|
126 |
+
return torch.cat([draft_tokens, last_token])
|
127 |
+
else:
|
128 |
+
accept_length = rejected_locations[0].item()
|
129 |
+
p = draft_probs[accept_length]
|
130 |
+
q = target_probs[accept_length]
|
131 |
+
new = q - p
|
132 |
+
new = torch.where(new > 0, new, 0.0)
|
133 |
+
new = new / new.sum()
|
134 |
+
next_token = multinomial_sample_one_no_sync(new)
|
135 |
+
return torch.cat([draft_tokens[:accept_length], next_token])
|
136 |
+
|
137 |
+
@torch.no_grad()
|
138 |
+
def generate(
|
139 |
+
model: Transformer,
|
140 |
+
prompt: torch.Tensor,
|
141 |
+
max_new_tokens: int,
|
142 |
+
*,
|
143 |
+
interactive: bool,
|
144 |
+
draft_model: Transformer,
|
145 |
+
speculate_k: Optional[int] = 8,
|
146 |
+
callback = lambda x: x,
|
147 |
+
**sampling_kwargs
|
148 |
+
) -> torch.Tensor:
|
149 |
+
"""
|
150 |
+
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
151 |
+
"""
|
152 |
+
|
153 |
+
is_speculative = draft_model is not None
|
154 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
155 |
+
T = prompt.size(0)
|
156 |
+
T_new = T + max_new_tokens
|
157 |
+
if interactive:
|
158 |
+
max_seq_length = 350
|
159 |
+
else:
|
160 |
+
max_seq_length = min(T_new, model.config.block_size)
|
161 |
+
|
162 |
+
device, dtype = prompt.device, prompt.dtype
|
163 |
+
max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
|
164 |
+
with torch.device(device):
|
165 |
+
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
|
166 |
+
if is_speculative and draft_model is not model:
|
167 |
+
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
|
168 |
+
|
169 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
170 |
+
empty = torch.empty(T_new, dtype=dtype, device=device)
|
171 |
+
empty[:T] = prompt
|
172 |
+
seq = empty
|
173 |
+
input_pos = torch.arange(0, T, device=device)
|
174 |
+
|
175 |
+
next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
|
176 |
+
if is_speculative:
|
177 |
+
prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
|
178 |
+
seq[T] = next_token
|
179 |
+
|
180 |
+
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
181 |
+
accept_counts = [0] * (speculate_k + 1)
|
182 |
+
|
183 |
+
if is_speculative:
|
184 |
+
input_pos = input_pos.item() # for speculative decoding easier to keep on host
|
185 |
+
while input_pos < T_new - 1:
|
186 |
+
cur_token = next_token.view(())
|
187 |
+
|
188 |
+
next_tokens = speculative_decode(
|
189 |
+
model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
|
190 |
+
)
|
191 |
+
|
192 |
+
accept_counts[len(next_tokens) - 1] += 1
|
193 |
+
num_added = min(T_new - input_pos - 1, len(next_tokens))
|
194 |
+
seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
|
195 |
+
for i in next_tokens[: num_added,]:
|
196 |
+
callback(i)
|
197 |
+
input_pos = input_pos + num_added
|
198 |
+
next_token = next_tokens[-1]
|
199 |
+
else:
|
200 |
+
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
|
201 |
+
seq[T + 1:] = torch.cat(generated_tokens)
|
202 |
+
|
203 |
+
generate_stats = {
|
204 |
+
'accept_counts': accept_counts
|
205 |
+
}
|
206 |
+
return seq, generate_stats
|
207 |
+
|
208 |
+
def encode_tokens(tokenizer, string, bos=True, device=default_device):
|
209 |
+
tokens = tokenizer.encode(string)
|
210 |
+
if bos:
|
211 |
+
tokens = [tokenizer.bos_id()] + tokens
|
212 |
+
return torch.tensor(tokens, dtype=torch.int, device=device)
|
213 |
+
|
214 |
+
def _load_model(checkpoint_path, device, precision, use_tp):
|
215 |
+
use_cuda = 'cuda' in device
|
216 |
+
with torch.device('meta'):
|
217 |
+
model = Transformer.from_name(checkpoint_path.parent.name)
|
218 |
+
|
219 |
+
if "int8" in str(checkpoint_path):
|
220 |
+
print("Using int8 weight-only quantization!")
|
221 |
+
from quantize import WeightOnlyInt8QuantHandler
|
222 |
+
simple_quantizer = WeightOnlyInt8QuantHandler(model)
|
223 |
+
model = simple_quantizer.convert_for_runtime()
|
224 |
+
|
225 |
+
if "int4" in str(checkpoint_path):
|
226 |
+
print("Using int4 weight-only quantization!")
|
227 |
+
path_comps = checkpoint_path.name.split(".")
|
228 |
+
groupsize = int(path_comps[-2][1:])
|
229 |
+
from quantize import WeightOnlyInt4QuantHandler
|
230 |
+
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
|
231 |
+
model = simple_quantizer.convert_for_runtime()
|
232 |
+
|
233 |
+
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
|
234 |
+
if "model" in checkpoint and "stories" in str(checkpoint_path):
|
235 |
+
checkpoint = checkpoint["model"]
|
236 |
+
model.load_state_dict(checkpoint, assign=True)
|
237 |
+
|
238 |
+
if use_tp:
|
239 |
+
from tp import apply_tp
|
240 |
+
print("Applying tensor parallel to model ...")
|
241 |
+
apply_tp(model)
|
242 |
+
|
243 |
+
model = model.to(device=device, dtype=precision)
|
244 |
+
return model.eval()
|
245 |
+
|
246 |
+
def _get_model_size(model):
|
247 |
+
model_size = 0
|
248 |
+
for name, child in model.named_children():
|
249 |
+
if not isinstance(child, torch.nn.Embedding):
|
250 |
+
model_size += sum(
|
251 |
+
[
|
252 |
+
p.numel() * p.dtype.itemsize
|
253 |
+
for p in itertools.chain(child.parameters(), child.buffers())
|
254 |
+
]
|
255 |
+
)
|
256 |
+
return model_size
|
257 |
+
|
258 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
259 |
+
|
260 |
+
def main(
|
261 |
+
prompt: str = "Hello, my name is",
|
262 |
+
interactive: bool = False,
|
263 |
+
num_samples: int = 5,
|
264 |
+
max_new_tokens: int = 100,
|
265 |
+
top_k: int = 200,
|
266 |
+
temperature: float = 0.8,
|
267 |
+
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
|
268 |
+
compile: bool = True,
|
269 |
+
compile_prefill: bool = False,
|
270 |
+
profile: Optional[Path] = None,
|
271 |
+
draft_checkpoint_path: Optional[Path] = None,
|
272 |
+
speculate_k: int = 5,
|
273 |
+
device=default_device,
|
274 |
+
) -> None:
|
275 |
+
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
|
276 |
+
"""
|
277 |
+
assert checkpoint_path.is_file(), checkpoint_path
|
278 |
+
|
279 |
+
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
|
280 |
+
assert tokenizer_path.is_file(), str(tokenizer_path)
|
281 |
+
|
282 |
+
global print
|
283 |
+
from tp import maybe_init_dist
|
284 |
+
rank = maybe_init_dist()
|
285 |
+
use_tp = rank is not None
|
286 |
+
if use_tp:
|
287 |
+
if rank != 0:
|
288 |
+
# only print on rank 0
|
289 |
+
print = lambda *args, **kwargs: None
|
290 |
+
|
291 |
+
print(f"Using device={device}")
|
292 |
+
precision = torch.bfloat16
|
293 |
+
is_speculative = draft_checkpoint_path is not None
|
294 |
+
is_chat = "chat" in str(checkpoint_path)
|
295 |
+
|
296 |
+
print("Loading model ...")
|
297 |
+
t0 = time.time()
|
298 |
+
model = _load_model(checkpoint_path, device, precision, use_tp)
|
299 |
+
|
300 |
+
if is_speculative:
|
301 |
+
draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
|
302 |
+
else:
|
303 |
+
draft_model = None
|
304 |
+
|
305 |
+
device_sync(device=device) # MKG
|
306 |
+
print(f"Time to load model: {time.time() - t0:.02f} seconds")
|
307 |
+
|
308 |
+
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
|
309 |
+
|
310 |
+
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
|
311 |
+
prompt_length = encoded.size(0)
|
312 |
+
|
313 |
+
torch.manual_seed(1234)
|
314 |
+
model_size = _get_model_size(model)
|
315 |
+
if compile:
|
316 |
+
if is_speculative and use_tp: # and ("cuda" in device):
|
317 |
+
torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
|
318 |
+
|
319 |
+
if is_speculative:
|
320 |
+
global model_forward, logits_to_prob
|
321 |
+
model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
|
322 |
+
|
323 |
+
global decode_one_token, prefill
|
324 |
+
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
|
325 |
+
|
326 |
+
# Uncomment to squeeze more perf out of prefill
|
327 |
+
if compile_prefill:
|
328 |
+
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
|
329 |
+
|
330 |
+
|
331 |
+
aggregate_metrics = {
|
332 |
+
'tokens_per_sec': [],
|
333 |
+
'accept_counts': [],
|
334 |
+
}
|
335 |
+
start = -1 if compile else 0
|
336 |
+
|
337 |
+
for i in range(start, num_samples):
|
338 |
+
device_sync(device=device) # MKG
|
339 |
+
if i >= 0 and interactive:
|
340 |
+
prompt = input("What is your prompt? ")
|
341 |
+
if is_chat:
|
342 |
+
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
|
343 |
+
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
|
344 |
+
|
345 |
+
if interactive and i >= 0:
|
346 |
+
buffer = []
|
347 |
+
period_id = tokenizer.encode('.')[0]
|
348 |
+
done_generating = False
|
349 |
+
def callback(x):
|
350 |
+
nonlocal done_generating
|
351 |
+
if done_generating:
|
352 |
+
return
|
353 |
+
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
|
354 |
+
if x.item() == tokenizer.eos_id():
|
355 |
+
done_generating = True
|
356 |
+
if len(buffer) == 4 or done_generating:
|
357 |
+
print(''.join(buffer), end='', flush=True)
|
358 |
+
buffer.clear()
|
359 |
+
# print(, end='', flush=True)
|
360 |
+
else:
|
361 |
+
callback = lambda x : x
|
362 |
+
t0 = time.perf_counter()
|
363 |
+
import contextlib
|
364 |
+
if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
|
365 |
+
prof = contextlib.nullcontext()
|
366 |
+
else:
|
367 |
+
torch.profiler._utils._init_for_cuda_graphs()
|
368 |
+
prof = torch.profiler.profile()
|
369 |
+
with prof:
|
370 |
+
y, metrics = generate(
|
371 |
+
model,
|
372 |
+
encoded,
|
373 |
+
max_new_tokens,
|
374 |
+
draft_model=draft_model,
|
375 |
+
speculate_k=speculate_k,
|
376 |
+
interactive=interactive,
|
377 |
+
callback=callback,
|
378 |
+
temperature=temperature,
|
379 |
+
top_k=top_k,
|
380 |
+
)
|
381 |
+
aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
|
382 |
+
if i == -1:
|
383 |
+
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
384 |
+
continue
|
385 |
+
if hasattr(prof, "export_chrome_trace"):
|
386 |
+
if use_tp:
|
387 |
+
prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
|
388 |
+
else:
|
389 |
+
prof.export_chrome_trace(f"{profile}.json")
|
390 |
+
device_sync(device=device) # MKG
|
391 |
+
t = time.perf_counter() - t0
|
392 |
+
|
393 |
+
if not interactive:
|
394 |
+
print(tokenizer.decode(y.tolist()))
|
395 |
+
else:
|
396 |
+
print()
|
397 |
+
tokens_generated = y.size(0) - prompt_length
|
398 |
+
tokens_sec = tokens_generated / t
|
399 |
+
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
|
400 |
+
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
|
401 |
+
print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
|
402 |
+
print("==========")
|
403 |
+
if is_speculative:
|
404 |
+
counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
|
405 |
+
acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
|
406 |
+
print(f"Acceptance probs: {acceptance_probs}")
|
407 |
+
print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")
|
408 |
+
|
409 |
+
print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
|
410 |
+
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
|
411 |
+
|
412 |
+
|
413 |
+
if __name__ == '__main__':
|
414 |
+
import argparse
|
415 |
+
parser = argparse.ArgumentParser(description='Your CLI description.')
|
416 |
+
|
417 |
+
parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
|
418 |
+
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
|
419 |
+
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
|
420 |
+
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
|
421 |
+
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
|
422 |
+
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
|
423 |
+
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
|
424 |
+
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
|
425 |
+
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
|
426 |
+
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
|
427 |
+
parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
|
428 |
+
parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
|
429 |
+
parser.add_argument('--device', type=str, default=default_device, help='Device to use')
|
430 |
+
|
431 |
+
args = parser.parse_args()
|
432 |
+
main(
|
433 |
+
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
|
434 |
+
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
|
435 |
+
args.speculate_k, args.device
|
436 |
+
)
|
modules/gpt_fast/model.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch import Tensor
|
12 |
+
from torch.nn import functional as F
|
13 |
+
|
14 |
+
|
15 |
+
def find_multiple(n: int, k: int) -> int:
|
16 |
+
if n % k == 0:
|
17 |
+
return n
|
18 |
+
return n + k - (n % k)
|
19 |
+
|
20 |
+
class AdaptiveLayerNorm(nn.Module):
|
21 |
+
r"""Adaptive Layer Normalization"""
|
22 |
+
|
23 |
+
def __init__(self, d_model, norm) -> None:
|
24 |
+
super(AdaptiveLayerNorm, self).__init__()
|
25 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
26 |
+
self.norm = norm
|
27 |
+
self.d_model = d_model
|
28 |
+
self.eps = self.norm.eps
|
29 |
+
|
30 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
31 |
+
if embedding is None:
|
32 |
+
return self.norm(input)
|
33 |
+
weight, bias = torch.split(
|
34 |
+
self.project_layer(embedding),
|
35 |
+
split_size_or_sections=self.d_model,
|
36 |
+
dim=-1,
|
37 |
+
)
|
38 |
+
return weight * self.norm(input) + bias
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
class ModelArgs:
|
43 |
+
block_size: int = 2048
|
44 |
+
vocab_size: int = 32000
|
45 |
+
n_layer: int = 32
|
46 |
+
n_head: int = 32
|
47 |
+
dim: int = 4096
|
48 |
+
intermediate_size: int = None
|
49 |
+
n_local_heads: int = -1
|
50 |
+
head_dim: int = 64
|
51 |
+
rope_base: float = 10000
|
52 |
+
norm_eps: float = 1e-5
|
53 |
+
has_cross_attention: bool = False
|
54 |
+
context_dim: int = 0
|
55 |
+
uvit_skip_connection: bool = False
|
56 |
+
|
57 |
+
def __post_init__(self):
|
58 |
+
if self.n_local_heads == -1:
|
59 |
+
self.n_local_heads = self.n_head
|
60 |
+
if self.intermediate_size is None:
|
61 |
+
hidden_dim = 4 * self.dim
|
62 |
+
n_hidden = int(2 * hidden_dim / 3)
|
63 |
+
self.intermediate_size = find_multiple(n_hidden, 256)
|
64 |
+
# self.head_dim = self.dim // self.n_head
|
65 |
+
|
66 |
+
@classmethod
|
67 |
+
def from_name(cls, name: str):
|
68 |
+
if name in transformer_configs:
|
69 |
+
return cls(**transformer_configs[name])
|
70 |
+
# fuzzy search
|
71 |
+
config = [config for config in transformer_configs if config.lower() in str(name).lower()]
|
72 |
+
|
73 |
+
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
|
74 |
+
# take longer name (as it have more symbols matched)
|
75 |
+
if len(config) > 1:
|
76 |
+
config.sort(key=len, reverse=True)
|
77 |
+
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
|
78 |
+
|
79 |
+
return cls(**transformer_configs[config[0]])
|
80 |
+
|
81 |
+
|
82 |
+
transformer_configs = {
|
83 |
+
"CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
|
84 |
+
"7B": dict(n_layer=32, n_head=32, dim=4096),
|
85 |
+
"13B": dict(n_layer=40, n_head=40, dim=5120),
|
86 |
+
"30B": dict(n_layer=60, n_head=52, dim=6656),
|
87 |
+
"34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016,
|
88 |
+
rope_base=1000000), # CodeLlama-34B-Python-hf
|
89 |
+
"70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
|
90 |
+
"Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
|
91 |
+
"stories15M": dict(n_layer=6, n_head=6, dim=288),
|
92 |
+
"stories110M": dict(n_layer=12, n_head=12, dim=768),
|
93 |
+
|
94 |
+
"llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336,
|
95 |
+
vocab_size=128256, rope_base=500000),
|
96 |
+
"llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672,
|
97 |
+
vocab_size=128256, rope_base=500000),
|
98 |
+
}
|
99 |
+
|
100 |
+
|
101 |
+
class KVCache(nn.Module):
|
102 |
+
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
|
103 |
+
super().__init__()
|
104 |
+
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
|
105 |
+
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
|
106 |
+
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
|
107 |
+
|
108 |
+
def update(self, input_pos, k_val, v_val):
|
109 |
+
# input_pos: [S], k_val: [B, H, S, D]
|
110 |
+
assert input_pos.shape[0] == k_val.shape[2]
|
111 |
+
|
112 |
+
k_out = self.k_cache
|
113 |
+
v_out = self.v_cache
|
114 |
+
k_out[:, :, input_pos] = k_val
|
115 |
+
v_out[:, :, input_pos] = v_val
|
116 |
+
|
117 |
+
return k_out, v_out
|
118 |
+
|
119 |
+
|
120 |
+
class Transformer(nn.Module):
|
121 |
+
def __init__(self, config: ModelArgs) -> None:
|
122 |
+
super().__init__()
|
123 |
+
self.config = config
|
124 |
+
|
125 |
+
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
|
126 |
+
self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
|
127 |
+
|
128 |
+
self.freqs_cis: Optional[Tensor] = None
|
129 |
+
self.mask_cache: Optional[Tensor] = None
|
130 |
+
self.max_batch_size = -1
|
131 |
+
self.max_seq_length = -1
|
132 |
+
|
133 |
+
def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=True):
|
134 |
+
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
|
135 |
+
return
|
136 |
+
head_dim = self.config.dim // self.config.n_head
|
137 |
+
max_seq_length = find_multiple(max_seq_length, 8)
|
138 |
+
self.max_seq_length = max_seq_length
|
139 |
+
self.max_batch_size = max_batch_size
|
140 |
+
dtype = self.norm.project_layer.weight.dtype
|
141 |
+
device = self.norm.project_layer.weight.device
|
142 |
+
|
143 |
+
if not self.training and use_kv_cache:
|
144 |
+
for b in self.layers:
|
145 |
+
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype).to(device)
|
146 |
+
|
147 |
+
self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
|
148 |
+
self.config.rope_base, dtype).to(device)
|
149 |
+
self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
|
150 |
+
self.use_kv_cache = use_kv_cache
|
151 |
+
self.uvit_skip_connection = self.config.uvit_skip_connection
|
152 |
+
if self.uvit_skip_connection:
|
153 |
+
self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
|
154 |
+
self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
|
155 |
+
else:
|
156 |
+
self.layers_emit_skip = []
|
157 |
+
self.layers_receive_skip = []
|
158 |
+
|
159 |
+
def forward(self,
|
160 |
+
x: Tensor,
|
161 |
+
c: Tensor,
|
162 |
+
input_pos: Optional[Tensor] = None,
|
163 |
+
mask: Optional[Tensor] = None,
|
164 |
+
context: Optional[Tensor] = None,
|
165 |
+
context_input_pos: Optional[Tensor] = None,
|
166 |
+
cross_attention_mask: Optional[Tensor] = None,
|
167 |
+
) -> Tensor:
|
168 |
+
assert self.freqs_cis is not None, "Caches must be initialized first"
|
169 |
+
if mask is None: # in case of non-causal model
|
170 |
+
if not self.training and self.use_kv_cache:
|
171 |
+
mask = self.causal_mask[None, None, input_pos]
|
172 |
+
else:
|
173 |
+
mask = self.causal_mask[None, None, input_pos]
|
174 |
+
mask = mask[..., input_pos]
|
175 |
+
freqs_cis = self.freqs_cis[input_pos]
|
176 |
+
if context is not None:
|
177 |
+
context_freqs_cis = self.freqs_cis[context_input_pos]
|
178 |
+
else:
|
179 |
+
context_freqs_cis = None
|
180 |
+
skip_in_x_list = []
|
181 |
+
for i, layer in enumerate(self.layers):
|
182 |
+
if self.uvit_skip_connection and i in self.layers_receive_skip:
|
183 |
+
skip_in_x = skip_in_x_list.pop(-1)
|
184 |
+
else:
|
185 |
+
skip_in_x = None
|
186 |
+
x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
|
187 |
+
if self.uvit_skip_connection and i in self.layers_emit_skip:
|
188 |
+
skip_in_x_list.append(x)
|
189 |
+
x = self.norm(x, c)
|
190 |
+
return x
|
191 |
+
|
192 |
+
@classmethod
|
193 |
+
def from_name(cls, name: str):
|
194 |
+
return cls(ModelArgs.from_name(name))
|
195 |
+
|
196 |
+
|
197 |
+
class TransformerBlock(nn.Module):
|
198 |
+
def __init__(self, config: ModelArgs) -> None:
|
199 |
+
super().__init__()
|
200 |
+
self.attention = Attention(config)
|
201 |
+
self.feed_forward = FeedForward(config)
|
202 |
+
self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
|
203 |
+
self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
|
204 |
+
|
205 |
+
if config.has_cross_attention:
|
206 |
+
self.has_cross_attention = True
|
207 |
+
self.cross_attention = Attention(config, is_cross_attention=True)
|
208 |
+
self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
|
209 |
+
else:
|
210 |
+
self.has_cross_attention = False
|
211 |
+
|
212 |
+
if config.uvit_skip_connection:
|
213 |
+
self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
|
214 |
+
self.uvit_skip_connection = True
|
215 |
+
else:
|
216 |
+
self.uvit_skip_connection = False
|
217 |
+
|
218 |
+
def forward(self,
|
219 |
+
x: Tensor,
|
220 |
+
c: Tensor,
|
221 |
+
input_pos: Tensor,
|
222 |
+
freqs_cis: Tensor,
|
223 |
+
mask: Tensor,
|
224 |
+
context: Optional[Tensor] = None,
|
225 |
+
context_freqs_cis: Optional[Tensor] = None,
|
226 |
+
cross_attention_mask: Optional[Tensor] = None,
|
227 |
+
skip_in_x: Optional[Tensor] = None,
|
228 |
+
) -> Tensor:
|
229 |
+
if self.uvit_skip_connection and skip_in_x is not None:
|
230 |
+
x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
|
231 |
+
h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
|
232 |
+
if self.has_cross_attention:
|
233 |
+
h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
|
234 |
+
out = h + self.feed_forward(self.ffn_norm(h, c))
|
235 |
+
return out
|
236 |
+
|
237 |
+
|
238 |
+
class Attention(nn.Module):
|
239 |
+
def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
|
240 |
+
super().__init__()
|
241 |
+
assert config.dim % config.n_head == 0
|
242 |
+
|
243 |
+
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
244 |
+
# key, query, value projections for all heads, but in a batch
|
245 |
+
if is_cross_attention:
|
246 |
+
self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
|
247 |
+
self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
|
248 |
+
else:
|
249 |
+
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
|
250 |
+
self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
|
251 |
+
self.kv_cache = None
|
252 |
+
|
253 |
+
self.n_head = config.n_head
|
254 |
+
self.head_dim = config.head_dim
|
255 |
+
self.n_local_heads = config.n_local_heads
|
256 |
+
self.dim = config.dim
|
257 |
+
# self._register_load_state_dict_pre_hook(self.load_hook)
|
258 |
+
|
259 |
+
# def load_hook(self, state_dict, prefix, *args):
|
260 |
+
# if prefix + "wq.weight" in state_dict:
|
261 |
+
# wq = state_dict.pop(prefix + "wq.weight")
|
262 |
+
# wk = state_dict.pop(prefix + "wk.weight")
|
263 |
+
# wv = state_dict.pop(prefix + "wv.weight")
|
264 |
+
# state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
265 |
+
|
266 |
+
def forward(self,
|
267 |
+
x: Tensor,
|
268 |
+
freqs_cis: Tensor,
|
269 |
+
mask: Tensor,
|
270 |
+
input_pos: Optional[Tensor] = None,
|
271 |
+
context: Optional[Tensor] = None,
|
272 |
+
context_freqs_cis: Optional[Tensor] = None,
|
273 |
+
) -> Tensor:
|
274 |
+
bsz, seqlen, _ = x.shape
|
275 |
+
|
276 |
+
kv_size = self.n_local_heads * self.head_dim
|
277 |
+
if context is None:
|
278 |
+
q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
|
279 |
+
context_seqlen = seqlen
|
280 |
+
else:
|
281 |
+
q = self.wq(x)
|
282 |
+
k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
|
283 |
+
context_seqlen = context.shape[1]
|
284 |
+
|
285 |
+
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
286 |
+
k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
287 |
+
v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
288 |
+
|
289 |
+
q = apply_rotary_emb(q, freqs_cis)
|
290 |
+
k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
|
291 |
+
|
292 |
+
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
293 |
+
|
294 |
+
if self.kv_cache is not None:
|
295 |
+
k, v = self.kv_cache.update(input_pos, k, v)
|
296 |
+
|
297 |
+
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
298 |
+
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
299 |
+
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
300 |
+
|
301 |
+
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
|
302 |
+
|
303 |
+
y = self.wo(y)
|
304 |
+
return y
|
305 |
+
|
306 |
+
|
307 |
+
class FeedForward(nn.Module):
|
308 |
+
def __init__(self, config: ModelArgs) -> None:
|
309 |
+
super().__init__()
|
310 |
+
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
311 |
+
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
312 |
+
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
|
313 |
+
|
314 |
+
def forward(self, x: Tensor) -> Tensor:
|
315 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
316 |
+
|
317 |
+
|
318 |
+
class RMSNorm(nn.Module):
|
319 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
320 |
+
super().__init__()
|
321 |
+
self.eps = eps
|
322 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
323 |
+
|
324 |
+
def _norm(self, x):
|
325 |
+
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
326 |
+
|
327 |
+
def forward(self, x: Tensor) -> Tensor:
|
328 |
+
output = self._norm(x.float()).type_as(x)
|
329 |
+
return output * self.weight
|
330 |
+
|
331 |
+
|
332 |
+
def precompute_freqs_cis(
|
333 |
+
seq_len: int, n_elem: int, base: int = 10000,
|
334 |
+
dtype: torch.dtype = torch.bfloat16
|
335 |
+
) -> Tensor:
|
336 |
+
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
|
337 |
+
t = torch.arange(seq_len, device=freqs.device)
|
338 |
+
freqs = torch.outer(t, freqs)
|
339 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
340 |
+
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
|
341 |
+
return cache.to(dtype=dtype)
|
342 |
+
|
343 |
+
|
344 |
+
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
345 |
+
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
|
346 |
+
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
|
347 |
+
x_out2 = torch.stack(
|
348 |
+
[
|
349 |
+
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
350 |
+
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
351 |
+
],
|
352 |
+
-1,
|
353 |
+
)
|
354 |
+
|
355 |
+
x_out2 = x_out2.flatten(3)
|
356 |
+
return x_out2.type_as(x)
|
modules/gpt_fast/quantize.py
ADDED
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import time
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from tokenizer import get_tokenizer
|
13 |
+
|
14 |
+
try:
|
15 |
+
from GPTQ import GenericGPTQRunner, InputRecorder
|
16 |
+
from eval import get_task_dict, evaluate, lm_eval
|
17 |
+
except:
|
18 |
+
pass
|
19 |
+
|
20 |
+
from model import Transformer
|
21 |
+
|
22 |
+
##### Quantization Primitives ######
|
23 |
+
|
24 |
+
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
|
25 |
+
# assumes symmetric quantization
|
26 |
+
# assumes axis == 0
|
27 |
+
# assumes dense memory format
|
28 |
+
# TODO(future): relax ^ as needed
|
29 |
+
|
30 |
+
# default setup for affine quantization of activations
|
31 |
+
eps = torch.finfo(torch.float32).eps
|
32 |
+
|
33 |
+
# get min and max
|
34 |
+
min_val, max_val = torch.aminmax(x, dim=1)
|
35 |
+
|
36 |
+
# calculate scales and zero_points based on min and max
|
37 |
+
# reference: https://fburl.com/code/srbiybme
|
38 |
+
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
|
39 |
+
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
|
40 |
+
device = min_val_neg.device
|
41 |
+
|
42 |
+
# reference: https://fburl.com/code/4wll53rk
|
43 |
+
max_val_pos = torch.max(-min_val_neg, max_val_pos)
|
44 |
+
scales = max_val_pos / (float(quant_max - quant_min) / 2)
|
45 |
+
# ensure scales is the same dtype as the original tensor
|
46 |
+
scales = torch.clamp(scales, min=eps).to(x.dtype)
|
47 |
+
zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
|
48 |
+
|
49 |
+
# quantize based on qmin/qmax/scales/zp
|
50 |
+
# reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
|
51 |
+
x_div = x / scales.unsqueeze(-1)
|
52 |
+
x_round = torch.round(x_div)
|
53 |
+
x_zp = x_round + zero_points.unsqueeze(-1)
|
54 |
+
quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
|
55 |
+
|
56 |
+
return quant, scales, zero_points
|
57 |
+
|
58 |
+
def get_group_qparams(w, n_bit=4, groupsize=128):
|
59 |
+
# needed for GPTQ with padding
|
60 |
+
if groupsize > w.shape[-1]:
|
61 |
+
groupsize = w.shape[-1]
|
62 |
+
assert groupsize > 1
|
63 |
+
assert w.shape[-1] % groupsize == 0
|
64 |
+
assert w.dim() == 2
|
65 |
+
|
66 |
+
to_quant = w.reshape(-1, groupsize)
|
67 |
+
assert torch.isnan(to_quant).sum() == 0
|
68 |
+
|
69 |
+
max_val = to_quant.amax(dim=1, keepdim=True)
|
70 |
+
min_val = to_quant.amin(dim=1, keepdim=True)
|
71 |
+
max_int = 2**n_bit - 1
|
72 |
+
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
73 |
+
zeros = min_val + scales * (2 ** (n_bit - 1))
|
74 |
+
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
75 |
+
torch.bfloat16
|
76 |
+
).reshape(w.shape[0], -1)
|
77 |
+
|
78 |
+
|
79 |
+
def pack_scales_and_zeros(scales, zeros):
|
80 |
+
assert scales.shape == zeros.shape
|
81 |
+
assert scales.dtype == torch.bfloat16
|
82 |
+
assert zeros.dtype == torch.bfloat16
|
83 |
+
return (
|
84 |
+
torch.cat(
|
85 |
+
[
|
86 |
+
scales.reshape(scales.size(0), scales.size(1), 1),
|
87 |
+
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
88 |
+
],
|
89 |
+
2,
|
90 |
+
)
|
91 |
+
.transpose(0, 1)
|
92 |
+
.contiguous()
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
def unpack_scales_and_zeros(scales_and_zeros):
|
97 |
+
assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
|
98 |
+
assert scales_and_zeros.dtype == torch.float
|
99 |
+
return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
|
100 |
+
|
101 |
+
|
102 |
+
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
103 |
+
assert groupsize > 1
|
104 |
+
# needed for GPTQ single column quantize
|
105 |
+
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
106 |
+
groupsize = w.shape[-1]
|
107 |
+
|
108 |
+
assert w.shape[-1] % groupsize == 0
|
109 |
+
assert w.dim() == 2
|
110 |
+
|
111 |
+
to_quant = w.reshape(-1, groupsize)
|
112 |
+
assert torch.isnan(to_quant).sum() == 0
|
113 |
+
|
114 |
+
scales = scales.reshape(-1, 1)
|
115 |
+
zeros = zeros.reshape(-1, 1)
|
116 |
+
min_val = zeros - scales * (2 ** (n_bit - 1))
|
117 |
+
max_int = 2**n_bit - 1
|
118 |
+
min_int = 0
|
119 |
+
w_int32 = (
|
120 |
+
to_quant.sub(min_val)
|
121 |
+
.div(scales)
|
122 |
+
.round()
|
123 |
+
.clamp_(min_int, max_int)
|
124 |
+
.to(torch.int32)
|
125 |
+
.reshape_as(w)
|
126 |
+
)
|
127 |
+
|
128 |
+
return w_int32
|
129 |
+
|
130 |
+
|
131 |
+
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
132 |
+
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
133 |
+
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
134 |
+
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
135 |
+
return w_int32, scales_and_zeros
|
136 |
+
|
137 |
+
|
138 |
+
def group_dequantize_tensor_from_qparams(
|
139 |
+
w_int32, scales, zeros, n_bit=4, groupsize=128
|
140 |
+
):
|
141 |
+
assert groupsize > 1
|
142 |
+
# needed for GPTQ single column dequantize
|
143 |
+
if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
|
144 |
+
groupsize = w_int32.shape[-1]
|
145 |
+
assert w_int32.shape[-1] % groupsize == 0
|
146 |
+
assert w_int32.dim() == 2
|
147 |
+
|
148 |
+
w_int32_grouped = w_int32.reshape(-1, groupsize)
|
149 |
+
scales = scales.reshape(-1, 1)
|
150 |
+
zeros = zeros.reshape(-1, 1)
|
151 |
+
|
152 |
+
w_dq = (
|
153 |
+
w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
|
154 |
+
)
|
155 |
+
return w_dq
|
156 |
+
|
157 |
+
|
158 |
+
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
|
159 |
+
scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
|
160 |
+
return group_dequantize_tensor_from_qparams(
|
161 |
+
w_int32, scales, zeros, n_bit, groupsize
|
162 |
+
)
|
163 |
+
|
164 |
+
class QuantHandler:
|
165 |
+
def __init__(self, mod):
|
166 |
+
self.mod = mod
|
167 |
+
|
168 |
+
def create_quantized_state_dict(self) -> "StateDict":
|
169 |
+
pass
|
170 |
+
|
171 |
+
def convert_for_runtime(self) -> "nn.Module":
|
172 |
+
pass
|
173 |
+
|
174 |
+
class GPTQQuantHandler(QuantHandler):
|
175 |
+
"""
|
176 |
+
This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
|
177 |
+
Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
|
178 |
+
__init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
|
179 |
+
|
180 |
+
The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
|
181 |
+
create_quantized_state_dict. Here is a description of each function.
|
182 |
+
|
183 |
+
get_qparams_func:
|
184 |
+
A function that calculates the quantization qparams for an input tensor.
|
185 |
+
Args:
|
186 |
+
weight: A 2d weight tensor with non-integer dtype.
|
187 |
+
Returns:
|
188 |
+
qparams: it can have any format but will need to be handled by the other defined functions below.
|
189 |
+
|
190 |
+
quantize_func:
|
191 |
+
A function that applies quantization to an input tensor. It should be noted
|
192 |
+
that this function needs to be able to handle quantizing the entire weight tensor, a single group,
|
193 |
+
or a single column.
|
194 |
+
Args:
|
195 |
+
weight: A 2d weight tensor with non-integer dtype.
|
196 |
+
qparams: the output from get_qparams_func
|
197 |
+
Returns:
|
198 |
+
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
|
199 |
+
|
200 |
+
|
201 |
+
dequantize_func:
|
202 |
+
A function that dequantizes an input quantized weight tensor. It should be noted
|
203 |
+
that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
|
204 |
+
or a single column.
|
205 |
+
Args:
|
206 |
+
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
|
207 |
+
qparams: the output from get_qparams_func
|
208 |
+
Returns:
|
209 |
+
weight: A 2d weight tensor with non-integer dtype.
|
210 |
+
|
211 |
+
combine_qparams_list_func:
|
212 |
+
A function that combines several qparams into one qparam.
|
213 |
+
Args:
|
214 |
+
qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
|
215 |
+
on a single group from a weight tensor
|
216 |
+
Returns:
|
217 |
+
qparams: an object of the same format as the qparams above.
|
218 |
+
|
219 |
+
skip_layer_func:
|
220 |
+
A function that determines which linear layers should be skipped during GPTQ
|
221 |
+
Args:
|
222 |
+
weight: A 2d weight tensor with non-integer dtype.
|
223 |
+
Returns:
|
224 |
+
skip: boolean indicating whether layer should be skipped
|
225 |
+
|
226 |
+
make_names_and_values_dict_func:
|
227 |
+
A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
|
228 |
+
should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
|
229 |
+
Args:
|
230 |
+
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
|
231 |
+
qparams: the output from get_qparams_func
|
232 |
+
Returns:
|
233 |
+
names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
|
234 |
+
corresponding quantized weights and qparams.
|
235 |
+
"""
|
236 |
+
def __init__(self):
|
237 |
+
assert self.mod is not None
|
238 |
+
assert self.get_qparams_func is not None
|
239 |
+
assert self.quantize_func is not None
|
240 |
+
assert self.dequantize_func is not None
|
241 |
+
assert self.combine_qparams_list_func is not None
|
242 |
+
assert self.make_names_and_values_dict_func is not None
|
243 |
+
|
244 |
+
@staticmethod
|
245 |
+
def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput":
|
246 |
+
input_recorder = InputRecorder(
|
247 |
+
model,
|
248 |
+
tokenizer,
|
249 |
+
calibration_seq_length,
|
250 |
+
pad_calibration_inputs,
|
251 |
+
)
|
252 |
+
|
253 |
+
try:
|
254 |
+
lm_eval.tasks.initialize_tasks()
|
255 |
+
except:
|
256 |
+
pass
|
257 |
+
task_dict = get_task_dict(calibration_tasks)
|
258 |
+
print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
|
259 |
+
|
260 |
+
evaluate(
|
261 |
+
input_recorder,
|
262 |
+
task_dict,
|
263 |
+
limit=calibration_limit,
|
264 |
+
)
|
265 |
+
inputs = input_recorder.get_recorded_inputs()
|
266 |
+
assert inputs is not None, (
|
267 |
+
f"No inputs were collected, use a task other than {calibration_tasks}, "+
|
268 |
+
f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+
|
269 |
+
f"{calibration_seq_length})"
|
270 |
+
)
|
271 |
+
print(f"Obtained {len(inputs[0].values)} calibration samples")
|
272 |
+
return inputs
|
273 |
+
|
274 |
+
@torch.no_grad()
|
275 |
+
def create_quantized_state_dict(
|
276 |
+
self,
|
277 |
+
tokenizer,
|
278 |
+
blocksize,
|
279 |
+
percdamp,
|
280 |
+
groupsize,
|
281 |
+
calibration_tasks,
|
282 |
+
calibration_limit,
|
283 |
+
calibration_seq_length,
|
284 |
+
pad_calibration_inputs,
|
285 |
+
) -> "StateDict":
|
286 |
+
inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
|
287 |
+
print("Tracing model for GPTQ")
|
288 |
+
GPTQ_runner = GenericGPTQRunner(
|
289 |
+
self.mod,
|
290 |
+
inputs,
|
291 |
+
blocksize,
|
292 |
+
percdamp,
|
293 |
+
groupsize,
|
294 |
+
).configure_quantization_mode(
|
295 |
+
self.get_qparams_func,
|
296 |
+
self.quantize_func,
|
297 |
+
self.dequantize_func,
|
298 |
+
self.combine_qparams_list_func,
|
299 |
+
self.make_names_and_values_dict_func,
|
300 |
+
self.skip_layer_func
|
301 |
+
)
|
302 |
+
|
303 |
+
print("Applying GPTQ to weights")
|
304 |
+
GPTQ_runner.run()
|
305 |
+
return GPTQ_runner.get_quantized_state_dict()
|
306 |
+
|
307 |
+
def convert_for_runtime(self) -> "nn.Module":
|
308 |
+
pass
|
309 |
+
|
310 |
+
##### Weight-only int8 per-channel quantized code ######
|
311 |
+
|
312 |
+
def replace_linear_weight_only_int8_per_channel(module):
|
313 |
+
for name, child in module.named_children():
|
314 |
+
if isinstance(child, nn.Linear):
|
315 |
+
setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features))
|
316 |
+
else:
|
317 |
+
replace_linear_weight_only_int8_per_channel(child)
|
318 |
+
|
319 |
+
class WeightOnlyInt8QuantHandler:
|
320 |
+
def __init__(self, mod):
|
321 |
+
self.mod = mod
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def create_quantized_state_dict(self):
|
325 |
+
cur_state_dict = self.mod.state_dict()
|
326 |
+
for fqn, mod in self.mod.named_modules():
|
327 |
+
if isinstance(mod, torch.nn.Linear):
|
328 |
+
int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
|
329 |
+
cur_state_dict[f"{fqn}.weight"] = int8_weight
|
330 |
+
cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
|
331 |
+
|
332 |
+
return cur_state_dict
|
333 |
+
|
334 |
+
def convert_for_runtime(self):
|
335 |
+
replace_linear_weight_only_int8_per_channel(self.mod)
|
336 |
+
return self.mod
|
337 |
+
|
338 |
+
|
339 |
+
class WeightOnlyInt8Linear(torch.nn.Module):
|
340 |
+
__constants__ = ['in_features', 'out_features']
|
341 |
+
in_features: int
|
342 |
+
out_features: int
|
343 |
+
weight: torch.Tensor
|
344 |
+
|
345 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
346 |
+
device=None, dtype=None) -> None:
|
347 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
348 |
+
super().__init__()
|
349 |
+
self.in_features = in_features
|
350 |
+
self.out_features = out_features
|
351 |
+
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
|
352 |
+
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
|
353 |
+
|
354 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
355 |
+
return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
|
356 |
+
|
357 |
+
##### weight only int4 per channel groupwise quantized code ######
|
358 |
+
|
359 |
+
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
360 |
+
weight_int32, scales_and_zeros = group_quantize_tensor(
|
361 |
+
weight_bf16, n_bit=4, groupsize=groupsize
|
362 |
+
)
|
363 |
+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
|
364 |
+
return weight_int4pack, scales_and_zeros
|
365 |
+
|
366 |
+
|
367 |
+
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
368 |
+
origin_x_size = x.size()
|
369 |
+
x = x.reshape(-1, origin_x_size[-1])
|
370 |
+
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
|
371 |
+
new_shape = origin_x_size[:-1] + (out_features,)
|
372 |
+
c = c.reshape(new_shape)
|
373 |
+
return c
|
374 |
+
|
375 |
+
|
376 |
+
def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
|
377 |
+
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
|
378 |
+
|
379 |
+
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
|
380 |
+
for name, child in module.named_children():
|
381 |
+
if isinstance(child, nn.Linear):
|
382 |
+
if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
|
383 |
+
setattr(module, name, WeightOnlyInt4Linear(
|
384 |
+
child.in_features, child.out_features, bias=False,
|
385 |
+
groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
|
386 |
+
))
|
387 |
+
elif padding:
|
388 |
+
setattr(module, name, WeightOnlyInt4Linear(
|
389 |
+
child.in_features, child.out_features, bias=False,
|
390 |
+
groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
|
391 |
+
))
|
392 |
+
else:
|
393 |
+
replace_linear_int4(child, groupsize, inner_k_tiles, padding)
|
394 |
+
|
395 |
+
|
396 |
+
class WeightOnlyInt4QuantHandler:
|
397 |
+
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
|
398 |
+
self.mod = mod
|
399 |
+
self.groupsize = groupsize
|
400 |
+
self.inner_k_tiles = inner_k_tiles
|
401 |
+
self.padding = padding
|
402 |
+
assert groupsize in [32, 64, 128, 256]
|
403 |
+
assert inner_k_tiles in [2, 4, 8]
|
404 |
+
|
405 |
+
@torch.no_grad()
|
406 |
+
def create_quantized_state_dict(self, use_cuda = True):
|
407 |
+
if use_cuda:
|
408 |
+
device="cuda"
|
409 |
+
else:
|
410 |
+
device="cpu"
|
411 |
+
|
412 |
+
cur_state_dict = self.mod.state_dict()
|
413 |
+
for fqn, mod in self.mod.named_modules():
|
414 |
+
if isinstance(mod, torch.nn.Linear):
|
415 |
+
assert not mod.bias
|
416 |
+
out_features = mod.out_features
|
417 |
+
in_features = mod.in_features
|
418 |
+
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
419 |
+
print(f"linear: {fqn}, in={in_features}, out={out_features}")
|
420 |
+
|
421 |
+
weight = mod.weight.data
|
422 |
+
if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
|
423 |
+
if self.padding:
|
424 |
+
from model import find_multiple
|
425 |
+
import torch.nn.functional as F
|
426 |
+
print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
|
427 |
+
padded_in_features = find_multiple(in_features, 1024)
|
428 |
+
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
|
429 |
+
else:
|
430 |
+
print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
|
431 |
+
"and that groupsize and inner_k_tiles*16 evenly divide into it")
|
432 |
+
continue
|
433 |
+
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
|
434 |
+
weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
|
435 |
+
)
|
436 |
+
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
|
437 |
+
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
|
438 |
+
|
439 |
+
return cur_state_dict
|
440 |
+
|
441 |
+
def convert_for_runtime(self):
|
442 |
+
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
|
443 |
+
return self.mod
|
444 |
+
|
445 |
+
class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
|
446 |
+
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
|
447 |
+
from model import find_multiple
|
448 |
+
self.mod = mod
|
449 |
+
self.groupsize = groupsize
|
450 |
+
self.inner_k_tiles = inner_k_tiles
|
451 |
+
self.padding = padding
|
452 |
+
self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
|
453 |
+
self.quantize_func = lambda w, qparams: \
|
454 |
+
group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
|
455 |
+
self.dequantize_func = lambda q, qparams: \
|
456 |
+
group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
|
457 |
+
self.combine_qparams_list_func = lambda qparams_list: \
|
458 |
+
[torch.cat(x, dim=1) for x in zip(*qparams_list)]
|
459 |
+
# skip unless padding=True or its correctly sized
|
460 |
+
self.skip_layer_func = lambda linear_weight: not (
|
461 |
+
_check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
|
462 |
+
)
|
463 |
+
# we need to do the padding here, both for q and the qparams if necessary
|
464 |
+
def make_names_and_values_dict_func(q, qparams):
|
465 |
+
k = q.shape[1]
|
466 |
+
new_k = find_multiple(k, 1024)
|
467 |
+
# how much we need to pad the weight
|
468 |
+
delta_k = new_k - q.shape[1]
|
469 |
+
final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
|
470 |
+
scales_and_zeros = pack_scales_and_zeros(*qparams)
|
471 |
+
# how many new groups we need for padded weight
|
472 |
+
delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
|
473 |
+
final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1)
|
474 |
+
return {"weight": final_q, "scales_and_zeros": final_s_and_z}
|
475 |
+
self.make_names_and_values_dict_func = make_names_and_values_dict_func
|
476 |
+
super().__init__()
|
477 |
+
|
478 |
+
|
479 |
+
def convert_for_runtime(self):
|
480 |
+
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
|
481 |
+
return self.mod
|
482 |
+
|
483 |
+
class WeightOnlyInt4Linear(torch.nn.Module):
|
484 |
+
__constants__ = ['in_features', 'out_features']
|
485 |
+
in_features: int
|
486 |
+
out_features: int
|
487 |
+
weight: torch.Tensor
|
488 |
+
|
489 |
+
def __init__(
|
490 |
+
self, in_features: int, out_features: int,
|
491 |
+
bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
|
492 |
+
) -> None:
|
493 |
+
super().__init__()
|
494 |
+
self.padding = padding
|
495 |
+
if padding:
|
496 |
+
from model import find_multiple
|
497 |
+
self.origin_in_features = in_features
|
498 |
+
in_features = find_multiple(in_features, 1024)
|
499 |
+
|
500 |
+
self.in_features = in_features
|
501 |
+
self.out_features = out_features
|
502 |
+
assert not bias, "require bias=False"
|
503 |
+
self.groupsize = groupsize
|
504 |
+
self.inner_k_tiles = inner_k_tiles
|
505 |
+
|
506 |
+
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
507 |
+
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
|
508 |
+
self.register_buffer(
|
509 |
+
"weight",
|
510 |
+
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
|
511 |
+
)
|
512 |
+
self.register_buffer(
|
513 |
+
"scales_and_zeros",
|
514 |
+
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
|
515 |
+
)
|
516 |
+
|
517 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
518 |
+
input = input.to(torch.bfloat16)
|
519 |
+
if self.padding:
|
520 |
+
import torch.nn.functional as F
|
521 |
+
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
|
522 |
+
return linear_forward_int4(
|
523 |
+
input,
|
524 |
+
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
|
525 |
+
)
|
526 |
+
|
527 |
+
|
528 |
+
def quantize(
|
529 |
+
checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
|
530 |
+
mode: str = 'int8',
|
531 |
+
# following arguments only available when setting int4 quantization.
|
532 |
+
groupsize: int = 128,
|
533 |
+
# following arguments only used for GPTQ
|
534 |
+
calibration_tasks: list = ["hellaswag"],
|
535 |
+
calibration_limit: int = 1000,
|
536 |
+
calibration_seq_length: int = 100,
|
537 |
+
pad_calibration_inputs: bool = False,
|
538 |
+
percdamp: float = .01,
|
539 |
+
blocksize: int = 128,
|
540 |
+
label: str = '',
|
541 |
+
) -> None:
|
542 |
+
assert checkpoint_path.is_file(), checkpoint_path
|
543 |
+
|
544 |
+
device = 'cpu'
|
545 |
+
precision = torch.bfloat16
|
546 |
+
|
547 |
+
print("Loading model ...")
|
548 |
+
t0 = time.time()
|
549 |
+
|
550 |
+
with torch.device('meta'):
|
551 |
+
model = Transformer.from_name(checkpoint_path.parent.name)
|
552 |
+
|
553 |
+
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
|
554 |
+
model.load_state_dict(checkpoint, assign=True)
|
555 |
+
model = model.to(dtype=precision, device=device)
|
556 |
+
|
557 |
+
if mode == 'int8':
|
558 |
+
print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
|
559 |
+
quant_handler = WeightOnlyInt8QuantHandler(model)
|
560 |
+
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
561 |
+
|
562 |
+
dir_name = checkpoint_path.parent
|
563 |
+
base_name = checkpoint_path.name
|
564 |
+
new_base_name = base_name.replace('.pth', f'{label}int8.pth')
|
565 |
+
|
566 |
+
elif mode == 'int4':
|
567 |
+
print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
|
568 |
+
quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
|
569 |
+
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
570 |
+
|
571 |
+
dir_name = checkpoint_path.parent
|
572 |
+
base_name = checkpoint_path.name
|
573 |
+
new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
|
574 |
+
|
575 |
+
elif mode == 'int4-gptq':
|
576 |
+
print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
|
577 |
+
quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)
|
578 |
+
|
579 |
+
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
|
580 |
+
assert tokenizer_path.is_file(), str(tokenizer_path)
|
581 |
+
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
|
582 |
+
|
583 |
+
quantized_state_dict = quant_handler.create_quantized_state_dict(
|
584 |
+
tokenizer,
|
585 |
+
blocksize,
|
586 |
+
percdamp,
|
587 |
+
groupsize,
|
588 |
+
calibration_tasks,
|
589 |
+
calibration_limit,
|
590 |
+
calibration_seq_length,
|
591 |
+
pad_calibration_inputs
|
592 |
+
)
|
593 |
+
|
594 |
+
dir_name = checkpoint_path.parent
|
595 |
+
base_name = checkpoint_path.name
|
596 |
+
new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
|
597 |
+
else:
|
598 |
+
raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
|
599 |
+
|
600 |
+
quantize_path = dir_name / new_base_name
|
601 |
+
print(f"Writing quantized weights to {quantize_path}")
|
602 |
+
quantize_path.unlink(missing_ok=True) # remove existing file if one already there
|
603 |
+
torch.save(quantized_state_dict, quantize_path)
|
604 |
+
print(f"Quantization complete took {time.time() - t0:.02f} seconds")
|
605 |
+
return
|
606 |
+
|
607 |
+
if __name__ == '__main__':
|
608 |
+
import argparse
|
609 |
+
parser = argparse.ArgumentParser(description='Quantize a model.')
|
610 |
+
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
|
611 |
+
parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
|
612 |
+
parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
|
613 |
+
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
|
614 |
+
parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
|
615 |
+
parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
|
616 |
+
parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
|
617 |
+
parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
|
618 |
+
parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
|
619 |
+
parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
|
620 |
+
|
621 |
+
args = parser.parse_args()
|
622 |
+
quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
|
modules/hifigan/__pycache__/f0_predictor.cpython-310.pyc
ADDED
Binary file (1.33 kB). View file
|
|
modules/hifigan/__pycache__/generator.cpython-310.pyc
ADDED
Binary file (13.3 kB). View file
|
|
modules/hifigan/f0_predictor.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.nn.utils import weight_norm
|
17 |
+
|
18 |
+
|
19 |
+
class ConvRNNF0Predictor(nn.Module):
|
20 |
+
def __init__(self,
|
21 |
+
num_class: int = 1,
|
22 |
+
in_channels: int = 80,
|
23 |
+
cond_channels: int = 512
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.num_class = num_class
|
28 |
+
self.condnet = nn.Sequential(
|
29 |
+
weight_norm(
|
30 |
+
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
|
31 |
+
),
|
32 |
+
nn.ELU(),
|
33 |
+
weight_norm(
|
34 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
35 |
+
),
|
36 |
+
nn.ELU(),
|
37 |
+
weight_norm(
|
38 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
39 |
+
),
|
40 |
+
nn.ELU(),
|
41 |
+
weight_norm(
|
42 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
43 |
+
),
|
44 |
+
nn.ELU(),
|
45 |
+
weight_norm(
|
46 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
47 |
+
),
|
48 |
+
nn.ELU(),
|
49 |
+
)
|
50 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
51 |
+
|
52 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
53 |
+
x = self.condnet(x)
|
54 |
+
x = x.transpose(1, 2)
|
55 |
+
return torch.abs(self.classifier(x).squeeze(-1))
|
modules/hifigan/generator.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""HIFI-GAN"""
|
16 |
+
|
17 |
+
import typing as tp
|
18 |
+
import numpy as np
|
19 |
+
from scipy.signal import get_window
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torch.nn import Conv1d
|
24 |
+
from torch.nn import ConvTranspose1d
|
25 |
+
from torch.nn.utils import remove_weight_norm
|
26 |
+
from torch.nn.utils import weight_norm
|
27 |
+
from torch.distributions.uniform import Uniform
|
28 |
+
|
29 |
+
from torch import sin
|
30 |
+
from torch.nn.parameter import Parameter
|
31 |
+
|
32 |
+
|
33 |
+
"""hifigan based generator implementation.
|
34 |
+
|
35 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
36 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
37 |
+
https://github.com/NVIDIA/BigVGAN
|
38 |
+
|
39 |
+
"""
|
40 |
+
class Snake(nn.Module):
|
41 |
+
'''
|
42 |
+
Implementation of a sine-based periodic activation function
|
43 |
+
Shape:
|
44 |
+
- Input: (B, C, T)
|
45 |
+
- Output: (B, C, T), same shape as the input
|
46 |
+
Parameters:
|
47 |
+
- alpha - trainable parameter
|
48 |
+
References:
|
49 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
50 |
+
https://arxiv.org/abs/2006.08195
|
51 |
+
Examples:
|
52 |
+
>>> a1 = snake(256)
|
53 |
+
>>> x = torch.randn(256)
|
54 |
+
>>> x = a1(x)
|
55 |
+
'''
|
56 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
57 |
+
'''
|
58 |
+
Initialization.
|
59 |
+
INPUT:
|
60 |
+
- in_features: shape of the input
|
61 |
+
- alpha: trainable parameter
|
62 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
63 |
+
alpha will be trained along with the rest of your model.
|
64 |
+
'''
|
65 |
+
super(Snake, self).__init__()
|
66 |
+
self.in_features = in_features
|
67 |
+
|
68 |
+
# initialize alpha
|
69 |
+
self.alpha_logscale = alpha_logscale
|
70 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
71 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
72 |
+
else: # linear scale alphas initialized to ones
|
73 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
74 |
+
|
75 |
+
self.alpha.requires_grad = alpha_trainable
|
76 |
+
|
77 |
+
self.no_div_by_zero = 0.000000001
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
'''
|
81 |
+
Forward pass of the function.
|
82 |
+
Applies the function to the input elementwise.
|
83 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
84 |
+
'''
|
85 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
86 |
+
if self.alpha_logscale:
|
87 |
+
alpha = torch.exp(alpha)
|
88 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
89 |
+
|
90 |
+
return x
|
91 |
+
|
92 |
+
def get_padding(kernel_size, dilation=1):
|
93 |
+
return int((kernel_size * dilation - dilation) / 2)
|
94 |
+
|
95 |
+
|
96 |
+
def init_weights(m, mean=0.0, std=0.01):
|
97 |
+
classname = m.__class__.__name__
|
98 |
+
if classname.find("Conv") != -1:
|
99 |
+
m.weight.data.normal_(mean, std)
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
class ResBlock(torch.nn.Module):
|
104 |
+
"""Residual block module in HiFiGAN/BigVGAN."""
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
channels: int = 512,
|
108 |
+
kernel_size: int = 3,
|
109 |
+
dilations: tp.List[int] = [1, 3, 5],
|
110 |
+
):
|
111 |
+
super(ResBlock, self).__init__()
|
112 |
+
self.convs1 = nn.ModuleList()
|
113 |
+
self.convs2 = nn.ModuleList()
|
114 |
+
|
115 |
+
for dilation in dilations:
|
116 |
+
self.convs1.append(
|
117 |
+
weight_norm(
|
118 |
+
Conv1d(
|
119 |
+
channels,
|
120 |
+
channels,
|
121 |
+
kernel_size,
|
122 |
+
1,
|
123 |
+
dilation=dilation,
|
124 |
+
padding=get_padding(kernel_size, dilation)
|
125 |
+
)
|
126 |
+
)
|
127 |
+
)
|
128 |
+
self.convs2.append(
|
129 |
+
weight_norm(
|
130 |
+
Conv1d(
|
131 |
+
channels,
|
132 |
+
channels,
|
133 |
+
kernel_size,
|
134 |
+
1,
|
135 |
+
dilation=1,
|
136 |
+
padding=get_padding(kernel_size, 1)
|
137 |
+
)
|
138 |
+
)
|
139 |
+
)
|
140 |
+
self.convs1.apply(init_weights)
|
141 |
+
self.convs2.apply(init_weights)
|
142 |
+
self.activations1 = nn.ModuleList([
|
143 |
+
Snake(channels, alpha_logscale=False)
|
144 |
+
for _ in range(len(self.convs1))
|
145 |
+
])
|
146 |
+
self.activations2 = nn.ModuleList([
|
147 |
+
Snake(channels, alpha_logscale=False)
|
148 |
+
for _ in range(len(self.convs2))
|
149 |
+
])
|
150 |
+
|
151 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
152 |
+
for idx in range(len(self.convs1)):
|
153 |
+
xt = self.activations1[idx](x)
|
154 |
+
xt = self.convs1[idx](xt)
|
155 |
+
xt = self.activations2[idx](xt)
|
156 |
+
xt = self.convs2[idx](xt)
|
157 |
+
x = xt + x
|
158 |
+
return x
|
159 |
+
|
160 |
+
def remove_weight_norm(self):
|
161 |
+
for idx in range(len(self.convs1)):
|
162 |
+
remove_weight_norm(self.convs1[idx])
|
163 |
+
remove_weight_norm(self.convs2[idx])
|
164 |
+
|
165 |
+
class SineGen(torch.nn.Module):
|
166 |
+
""" Definition of sine generator
|
167 |
+
SineGen(samp_rate, harmonic_num = 0,
|
168 |
+
sine_amp = 0.1, noise_std = 0.003,
|
169 |
+
voiced_threshold = 0,
|
170 |
+
flag_for_pulse=False)
|
171 |
+
samp_rate: sampling rate in Hz
|
172 |
+
harmonic_num: number of harmonic overtones (default 0)
|
173 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
174 |
+
noise_std: std of Gaussian noise (default 0.003)
|
175 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
176 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
177 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
178 |
+
segment is always sin(np.pi) or cos(0)
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(self, samp_rate, harmonic_num=0,
|
182 |
+
sine_amp=0.1, noise_std=0.003,
|
183 |
+
voiced_threshold=0):
|
184 |
+
super(SineGen, self).__init__()
|
185 |
+
self.sine_amp = sine_amp
|
186 |
+
self.noise_std = noise_std
|
187 |
+
self.harmonic_num = harmonic_num
|
188 |
+
self.sampling_rate = samp_rate
|
189 |
+
self.voiced_threshold = voiced_threshold
|
190 |
+
|
191 |
+
def _f02uv(self, f0):
|
192 |
+
# generate uv signal
|
193 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
194 |
+
return uv
|
195 |
+
|
196 |
+
@torch.no_grad()
|
197 |
+
def forward(self, f0):
|
198 |
+
"""
|
199 |
+
:param f0: [B, 1, sample_len], Hz
|
200 |
+
:return: [B, 1, sample_len]
|
201 |
+
"""
|
202 |
+
|
203 |
+
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
204 |
+
for i in range(self.harmonic_num + 1):
|
205 |
+
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
206 |
+
|
207 |
+
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
208 |
+
u_dist = Uniform(low=-np.pi, high=np.pi)
|
209 |
+
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
|
210 |
+
phase_vec[:, 0, :] = 0
|
211 |
+
|
212 |
+
# generate sine waveforms
|
213 |
+
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
|
214 |
+
|
215 |
+
# generate uv signal
|
216 |
+
uv = self._f02uv(f0)
|
217 |
+
|
218 |
+
# noise: for unvoiced should be similar to sine_amp
|
219 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
220 |
+
# . for voiced regions is self.noise_std
|
221 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
222 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
223 |
+
|
224 |
+
# first: set the unvoiced part to 0 by uv
|
225 |
+
# then: additive noise
|
226 |
+
sine_waves = sine_waves * uv + noise
|
227 |
+
return sine_waves, uv, noise
|
228 |
+
|
229 |
+
|
230 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
231 |
+
""" SourceModule for hn-nsf
|
232 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
233 |
+
add_noise_std=0.003, voiced_threshod=0)
|
234 |
+
sampling_rate: sampling_rate in Hz
|
235 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
236 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
237 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
238 |
+
note that amplitude of noise in unvoiced is decided
|
239 |
+
by sine_amp
|
240 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
241 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
242 |
+
F0_sampled (batchsize, length, 1)
|
243 |
+
Sine_source (batchsize, length, 1)
|
244 |
+
noise_source (batchsize, length 1)
|
245 |
+
uv (batchsize, length, 1)
|
246 |
+
"""
|
247 |
+
|
248 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
249 |
+
add_noise_std=0.003, voiced_threshod=0):
|
250 |
+
super(SourceModuleHnNSF, self).__init__()
|
251 |
+
|
252 |
+
self.sine_amp = sine_amp
|
253 |
+
self.noise_std = add_noise_std
|
254 |
+
|
255 |
+
# to produce sine waveforms
|
256 |
+
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
257 |
+
sine_amp, add_noise_std, voiced_threshod)
|
258 |
+
|
259 |
+
# to merge source harmonics into a single excitation
|
260 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
261 |
+
self.l_tanh = torch.nn.Tanh()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""
|
265 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
266 |
+
F0_sampled (batchsize, length, 1)
|
267 |
+
Sine_source (batchsize, length, 1)
|
268 |
+
noise_source (batchsize, length 1)
|
269 |
+
"""
|
270 |
+
# source for harmonic branch
|
271 |
+
with torch.no_grad():
|
272 |
+
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
273 |
+
sine_wavs = sine_wavs.transpose(1, 2)
|
274 |
+
uv = uv.transpose(1, 2)
|
275 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
276 |
+
|
277 |
+
# source for noise branch, in the same shape as uv
|
278 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
279 |
+
return sine_merge, noise, uv
|
280 |
+
|
281 |
+
|
282 |
+
class HiFTGenerator(nn.Module):
|
283 |
+
"""
|
284 |
+
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
285 |
+
https://arxiv.org/abs/2309.09493
|
286 |
+
"""
|
287 |
+
def __init__(
|
288 |
+
self,
|
289 |
+
in_channels: int = 80,
|
290 |
+
base_channels: int = 512,
|
291 |
+
nb_harmonics: int = 8,
|
292 |
+
sampling_rate: int = 22050,
|
293 |
+
nsf_alpha: float = 0.1,
|
294 |
+
nsf_sigma: float = 0.003,
|
295 |
+
nsf_voiced_threshold: float = 10,
|
296 |
+
upsample_rates: tp.List[int] = [8, 8],
|
297 |
+
upsample_kernel_sizes: tp.List[int] = [16, 16],
|
298 |
+
istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
299 |
+
resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
|
300 |
+
resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
301 |
+
source_resblock_kernel_sizes: tp.List[int] = [7, 11],
|
302 |
+
source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
|
303 |
+
lrelu_slope: float = 0.1,
|
304 |
+
audio_limit: float = 0.99,
|
305 |
+
f0_predictor: torch.nn.Module = None,
|
306 |
+
):
|
307 |
+
super(HiFTGenerator, self).__init__()
|
308 |
+
|
309 |
+
self.out_channels = 1
|
310 |
+
self.nb_harmonics = nb_harmonics
|
311 |
+
self.sampling_rate = sampling_rate
|
312 |
+
self.istft_params = istft_params
|
313 |
+
self.lrelu_slope = lrelu_slope
|
314 |
+
self.audio_limit = audio_limit
|
315 |
+
|
316 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
317 |
+
self.num_upsamples = len(upsample_rates)
|
318 |
+
self.m_source = SourceModuleHnNSF(
|
319 |
+
sampling_rate=sampling_rate,
|
320 |
+
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
321 |
+
harmonic_num=nb_harmonics,
|
322 |
+
sine_amp=nsf_alpha,
|
323 |
+
add_noise_std=nsf_sigma,
|
324 |
+
voiced_threshod=nsf_voiced_threshold)
|
325 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
326 |
+
|
327 |
+
self.conv_pre = weight_norm(
|
328 |
+
Conv1d(in_channels, base_channels, 7, 1, padding=3)
|
329 |
+
)
|
330 |
+
|
331 |
+
# Up
|
332 |
+
self.ups = nn.ModuleList()
|
333 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
334 |
+
self.ups.append(
|
335 |
+
weight_norm(
|
336 |
+
ConvTranspose1d(
|
337 |
+
base_channels // (2**i),
|
338 |
+
base_channels // (2**(i + 1)),
|
339 |
+
k,
|
340 |
+
u,
|
341 |
+
padding=(k - u) // 2,
|
342 |
+
)
|
343 |
+
)
|
344 |
+
)
|
345 |
+
|
346 |
+
# Down
|
347 |
+
self.source_downs = nn.ModuleList()
|
348 |
+
self.source_resblocks = nn.ModuleList()
|
349 |
+
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
350 |
+
downsample_cum_rates = np.cumprod(downsample_rates)
|
351 |
+
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
|
352 |
+
source_resblock_dilation_sizes)):
|
353 |
+
if u == 1:
|
354 |
+
self.source_downs.append(
|
355 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
356 |
+
)
|
357 |
+
else:
|
358 |
+
self.source_downs.append(
|
359 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
|
360 |
+
)
|
361 |
+
|
362 |
+
self.source_resblocks.append(
|
363 |
+
ResBlock(base_channels // (2 ** (i + 1)), k, d)
|
364 |
+
)
|
365 |
+
|
366 |
+
self.resblocks = nn.ModuleList()
|
367 |
+
for i in range(len(self.ups)):
|
368 |
+
ch = base_channels // (2**(i + 1))
|
369 |
+
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
370 |
+
self.resblocks.append(ResBlock(ch, k, d))
|
371 |
+
|
372 |
+
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
373 |
+
self.ups.apply(init_weights)
|
374 |
+
self.conv_post.apply(init_weights)
|
375 |
+
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
376 |
+
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
377 |
+
self.f0_predictor = f0_predictor
|
378 |
+
|
379 |
+
def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
|
380 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
381 |
+
|
382 |
+
har_source, _, _ = self.m_source(f0)
|
383 |
+
return har_source.transpose(1, 2)
|
384 |
+
|
385 |
+
def _stft(self, x):
|
386 |
+
spec = torch.stft(
|
387 |
+
x,
|
388 |
+
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
|
389 |
+
return_complex=True)
|
390 |
+
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
391 |
+
return spec[..., 0], spec[..., 1]
|
392 |
+
|
393 |
+
def _istft(self, magnitude, phase):
|
394 |
+
magnitude = torch.clip(magnitude, max=1e2)
|
395 |
+
real = magnitude * torch.cos(phase)
|
396 |
+
img = magnitude * torch.sin(phase)
|
397 |
+
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
398 |
+
return inverse_transform
|
399 |
+
|
400 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
401 |
+
f0 = self.f0_predictor(x)
|
402 |
+
s = self._f02source(f0)
|
403 |
+
|
404 |
+
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
405 |
+
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
406 |
+
|
407 |
+
x = self.conv_pre(x)
|
408 |
+
for i in range(self.num_upsamples):
|
409 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
410 |
+
x = self.ups[i](x)
|
411 |
+
|
412 |
+
if i == self.num_upsamples - 1:
|
413 |
+
x = self.reflection_pad(x)
|
414 |
+
|
415 |
+
# fusion
|
416 |
+
si = self.source_downs[i](s_stft)
|
417 |
+
si = self.source_resblocks[i](si)
|
418 |
+
x = x + si
|
419 |
+
|
420 |
+
xs = None
|
421 |
+
for j in range(self.num_kernels):
|
422 |
+
if xs is None:
|
423 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
424 |
+
else:
|
425 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
426 |
+
x = xs / self.num_kernels
|
427 |
+
|
428 |
+
x = F.leaky_relu(x)
|
429 |
+
x = self.conv_post(x)
|
430 |
+
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
431 |
+
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
|
432 |
+
|
433 |
+
x = self._istft(magnitude, phase)
|
434 |
+
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
435 |
+
return x
|
436 |
+
|
437 |
+
def remove_weight_norm(self):
|
438 |
+
print('Removing weight norm...')
|
439 |
+
for l in self.ups:
|
440 |
+
remove_weight_norm(l)
|
441 |
+
for l in self.resblocks:
|
442 |
+
l.remove_weight_norm()
|
443 |
+
remove_weight_norm(self.conv_pre)
|
444 |
+
remove_weight_norm(self.conv_post)
|
445 |
+
self.source_module.remove_weight_norm()
|
446 |
+
for l in self.source_downs:
|
447 |
+
remove_weight_norm(l)
|
448 |
+
for l in self.source_resblocks:
|
449 |
+
l.remove_weight_norm()
|
450 |
+
|
451 |
+
@torch.inference_mode()
|
452 |
+
def inference(self, mel: torch.Tensor) -> torch.Tensor:
|
453 |
+
return self.forward(x=mel)
|
modules/layers.py
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from typing import Optional, Any
|
5 |
+
from torch import Tensor
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
import torchaudio.functional as audio_F
|
9 |
+
|
10 |
+
import random
|
11 |
+
random.seed(0)
|
12 |
+
|
13 |
+
|
14 |
+
def _get_activation_fn(activ):
|
15 |
+
if activ == 'relu':
|
16 |
+
return nn.ReLU()
|
17 |
+
elif activ == 'lrelu':
|
18 |
+
return nn.LeakyReLU(0.2)
|
19 |
+
elif activ == 'swish':
|
20 |
+
return lambda x: x*torch.sigmoid(x)
|
21 |
+
else:
|
22 |
+
raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
|
23 |
+
|
24 |
+
class LinearNorm(torch.nn.Module):
|
25 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
26 |
+
super(LinearNorm, self).__init__()
|
27 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
28 |
+
|
29 |
+
torch.nn.init.xavier_uniform_(
|
30 |
+
self.linear_layer.weight,
|
31 |
+
gain=torch.nn.init.calculate_gain(w_init_gain))
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
return self.linear_layer(x)
|
35 |
+
|
36 |
+
|
37 |
+
class ConvNorm(torch.nn.Module):
|
38 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
39 |
+
padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
|
40 |
+
super(ConvNorm, self).__init__()
|
41 |
+
if padding is None:
|
42 |
+
assert(kernel_size % 2 == 1)
|
43 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
44 |
+
|
45 |
+
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
46 |
+
kernel_size=kernel_size, stride=stride,
|
47 |
+
padding=padding, dilation=dilation,
|
48 |
+
bias=bias)
|
49 |
+
|
50 |
+
torch.nn.init.xavier_uniform_(
|
51 |
+
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
|
52 |
+
|
53 |
+
def forward(self, signal):
|
54 |
+
conv_signal = self.conv(signal)
|
55 |
+
return conv_signal
|
56 |
+
|
57 |
+
class CausualConv(nn.Module):
|
58 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
|
59 |
+
super(CausualConv, self).__init__()
|
60 |
+
if padding is None:
|
61 |
+
assert(kernel_size % 2 == 1)
|
62 |
+
padding = int(dilation * (kernel_size - 1) / 2) * 2
|
63 |
+
else:
|
64 |
+
self.padding = padding * 2
|
65 |
+
self.conv = nn.Conv1d(in_channels, out_channels,
|
66 |
+
kernel_size=kernel_size, stride=stride,
|
67 |
+
padding=self.padding,
|
68 |
+
dilation=dilation,
|
69 |
+
bias=bias)
|
70 |
+
|
71 |
+
torch.nn.init.xavier_uniform_(
|
72 |
+
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
x = self.conv(x)
|
76 |
+
x = x[:, :, :-self.padding]
|
77 |
+
return x
|
78 |
+
|
79 |
+
class CausualBlock(nn.Module):
|
80 |
+
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
|
81 |
+
super(CausualBlock, self).__init__()
|
82 |
+
self.blocks = nn.ModuleList([
|
83 |
+
self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
|
84 |
+
for i in range(n_conv)])
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
for block in self.blocks:
|
88 |
+
res = x
|
89 |
+
x = block(x)
|
90 |
+
x += res
|
91 |
+
return x
|
92 |
+
|
93 |
+
def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
|
94 |
+
layers = [
|
95 |
+
CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
|
96 |
+
_get_activation_fn(activ),
|
97 |
+
nn.BatchNorm1d(hidden_dim),
|
98 |
+
nn.Dropout(p=dropout_p),
|
99 |
+
CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
|
100 |
+
_get_activation_fn(activ),
|
101 |
+
nn.Dropout(p=dropout_p)
|
102 |
+
]
|
103 |
+
return nn.Sequential(*layers)
|
104 |
+
|
105 |
+
class ConvBlock(nn.Module):
|
106 |
+
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
|
107 |
+
super().__init__()
|
108 |
+
self._n_groups = 8
|
109 |
+
self.blocks = nn.ModuleList([
|
110 |
+
self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
|
111 |
+
for i in range(n_conv)])
|
112 |
+
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
for block in self.blocks:
|
116 |
+
res = x
|
117 |
+
x = block(x)
|
118 |
+
x += res
|
119 |
+
return x
|
120 |
+
|
121 |
+
def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
|
122 |
+
layers = [
|
123 |
+
ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
|
124 |
+
_get_activation_fn(activ),
|
125 |
+
nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
|
126 |
+
nn.Dropout(p=dropout_p),
|
127 |
+
ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
|
128 |
+
_get_activation_fn(activ),
|
129 |
+
nn.Dropout(p=dropout_p)
|
130 |
+
]
|
131 |
+
return nn.Sequential(*layers)
|
132 |
+
|
133 |
+
class LocationLayer(nn.Module):
|
134 |
+
def __init__(self, attention_n_filters, attention_kernel_size,
|
135 |
+
attention_dim):
|
136 |
+
super(LocationLayer, self).__init__()
|
137 |
+
padding = int((attention_kernel_size - 1) / 2)
|
138 |
+
self.location_conv = ConvNorm(2, attention_n_filters,
|
139 |
+
kernel_size=attention_kernel_size,
|
140 |
+
padding=padding, bias=False, stride=1,
|
141 |
+
dilation=1)
|
142 |
+
self.location_dense = LinearNorm(attention_n_filters, attention_dim,
|
143 |
+
bias=False, w_init_gain='tanh')
|
144 |
+
|
145 |
+
def forward(self, attention_weights_cat):
|
146 |
+
processed_attention = self.location_conv(attention_weights_cat)
|
147 |
+
processed_attention = processed_attention.transpose(1, 2)
|
148 |
+
processed_attention = self.location_dense(processed_attention)
|
149 |
+
return processed_attention
|
150 |
+
|
151 |
+
|
152 |
+
class Attention(nn.Module):
|
153 |
+
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
|
154 |
+
attention_location_n_filters, attention_location_kernel_size):
|
155 |
+
super(Attention, self).__init__()
|
156 |
+
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
|
157 |
+
bias=False, w_init_gain='tanh')
|
158 |
+
self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
|
159 |
+
w_init_gain='tanh')
|
160 |
+
self.v = LinearNorm(attention_dim, 1, bias=False)
|
161 |
+
self.location_layer = LocationLayer(attention_location_n_filters,
|
162 |
+
attention_location_kernel_size,
|
163 |
+
attention_dim)
|
164 |
+
self.score_mask_value = -float("inf")
|
165 |
+
|
166 |
+
def get_alignment_energies(self, query, processed_memory,
|
167 |
+
attention_weights_cat):
|
168 |
+
"""
|
169 |
+
PARAMS
|
170 |
+
------
|
171 |
+
query: decoder output (batch, n_mel_channels * n_frames_per_step)
|
172 |
+
processed_memory: processed encoder outputs (B, T_in, attention_dim)
|
173 |
+
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
|
174 |
+
RETURNS
|
175 |
+
-------
|
176 |
+
alignment (batch, max_time)
|
177 |
+
"""
|
178 |
+
|
179 |
+
processed_query = self.query_layer(query.unsqueeze(1))
|
180 |
+
processed_attention_weights = self.location_layer(attention_weights_cat)
|
181 |
+
energies = self.v(torch.tanh(
|
182 |
+
processed_query + processed_attention_weights + processed_memory))
|
183 |
+
|
184 |
+
energies = energies.squeeze(-1)
|
185 |
+
return energies
|
186 |
+
|
187 |
+
def forward(self, attention_hidden_state, memory, processed_memory,
|
188 |
+
attention_weights_cat, mask):
|
189 |
+
"""
|
190 |
+
PARAMS
|
191 |
+
------
|
192 |
+
attention_hidden_state: attention rnn last output
|
193 |
+
memory: encoder outputs
|
194 |
+
processed_memory: processed encoder outputs
|
195 |
+
attention_weights_cat: previous and cummulative attention weights
|
196 |
+
mask: binary mask for padded data
|
197 |
+
"""
|
198 |
+
alignment = self.get_alignment_energies(
|
199 |
+
attention_hidden_state, processed_memory, attention_weights_cat)
|
200 |
+
|
201 |
+
if mask is not None:
|
202 |
+
alignment.data.masked_fill_(mask, self.score_mask_value)
|
203 |
+
|
204 |
+
attention_weights = F.softmax(alignment, dim=1)
|
205 |
+
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
|
206 |
+
attention_context = attention_context.squeeze(1)
|
207 |
+
|
208 |
+
return attention_context, attention_weights
|
209 |
+
|
210 |
+
|
211 |
+
class ForwardAttentionV2(nn.Module):
|
212 |
+
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
|
213 |
+
attention_location_n_filters, attention_location_kernel_size):
|
214 |
+
super(ForwardAttentionV2, self).__init__()
|
215 |
+
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
|
216 |
+
bias=False, w_init_gain='tanh')
|
217 |
+
self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
|
218 |
+
w_init_gain='tanh')
|
219 |
+
self.v = LinearNorm(attention_dim, 1, bias=False)
|
220 |
+
self.location_layer = LocationLayer(attention_location_n_filters,
|
221 |
+
attention_location_kernel_size,
|
222 |
+
attention_dim)
|
223 |
+
self.score_mask_value = -float(1e20)
|
224 |
+
|
225 |
+
def get_alignment_energies(self, query, processed_memory,
|
226 |
+
attention_weights_cat):
|
227 |
+
"""
|
228 |
+
PARAMS
|
229 |
+
------
|
230 |
+
query: decoder output (batch, n_mel_channels * n_frames_per_step)
|
231 |
+
processed_memory: processed encoder outputs (B, T_in, attention_dim)
|
232 |
+
attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
|
233 |
+
RETURNS
|
234 |
+
-------
|
235 |
+
alignment (batch, max_time)
|
236 |
+
"""
|
237 |
+
|
238 |
+
processed_query = self.query_layer(query.unsqueeze(1))
|
239 |
+
processed_attention_weights = self.location_layer(attention_weights_cat)
|
240 |
+
energies = self.v(torch.tanh(
|
241 |
+
processed_query + processed_attention_weights + processed_memory))
|
242 |
+
|
243 |
+
energies = energies.squeeze(-1)
|
244 |
+
return energies
|
245 |
+
|
246 |
+
def forward(self, attention_hidden_state, memory, processed_memory,
|
247 |
+
attention_weights_cat, mask, log_alpha):
|
248 |
+
"""
|
249 |
+
PARAMS
|
250 |
+
------
|
251 |
+
attention_hidden_state: attention rnn last output
|
252 |
+
memory: encoder outputs
|
253 |
+
processed_memory: processed encoder outputs
|
254 |
+
attention_weights_cat: previous and cummulative attention weights
|
255 |
+
mask: binary mask for padded data
|
256 |
+
"""
|
257 |
+
log_energy = self.get_alignment_energies(
|
258 |
+
attention_hidden_state, processed_memory, attention_weights_cat)
|
259 |
+
|
260 |
+
#log_energy =
|
261 |
+
|
262 |
+
if mask is not None:
|
263 |
+
log_energy.data.masked_fill_(mask, self.score_mask_value)
|
264 |
+
|
265 |
+
#attention_weights = F.softmax(alignment, dim=1)
|
266 |
+
|
267 |
+
#content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
|
268 |
+
#log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
|
269 |
+
|
270 |
+
#log_total_score = log_alpha + content_score
|
271 |
+
|
272 |
+
#previous_attention_weights = attention_weights_cat[:,0,:]
|
273 |
+
|
274 |
+
log_alpha_shift_padded = []
|
275 |
+
max_time = log_energy.size(1)
|
276 |
+
for sft in range(2):
|
277 |
+
shifted = log_alpha[:,:max_time-sft]
|
278 |
+
shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
|
279 |
+
log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
|
280 |
+
|
281 |
+
biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
|
282 |
+
|
283 |
+
log_alpha_new = biased + log_energy
|
284 |
+
|
285 |
+
attention_weights = F.softmax(log_alpha_new, dim=1)
|
286 |
+
|
287 |
+
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
|
288 |
+
attention_context = attention_context.squeeze(1)
|
289 |
+
|
290 |
+
return attention_context, attention_weights, log_alpha_new
|
291 |
+
|
292 |
+
|
293 |
+
class PhaseShuffle2d(nn.Module):
|
294 |
+
def __init__(self, n=2):
|
295 |
+
super(PhaseShuffle2d, self).__init__()
|
296 |
+
self.n = n
|
297 |
+
self.random = random.Random(1)
|
298 |
+
|
299 |
+
def forward(self, x, move=None):
|
300 |
+
# x.size = (B, C, M, L)
|
301 |
+
if move is None:
|
302 |
+
move = self.random.randint(-self.n, self.n)
|
303 |
+
|
304 |
+
if move == 0:
|
305 |
+
return x
|
306 |
+
else:
|
307 |
+
left = x[:, :, :, :move]
|
308 |
+
right = x[:, :, :, move:]
|
309 |
+
shuffled = torch.cat([right, left], dim=3)
|
310 |
+
return shuffled
|
311 |
+
|
312 |
+
class PhaseShuffle1d(nn.Module):
|
313 |
+
def __init__(self, n=2):
|
314 |
+
super(PhaseShuffle1d, self).__init__()
|
315 |
+
self.n = n
|
316 |
+
self.random = random.Random(1)
|
317 |
+
|
318 |
+
def forward(self, x, move=None):
|
319 |
+
# x.size = (B, C, M, L)
|
320 |
+
if move is None:
|
321 |
+
move = self.random.randint(-self.n, self.n)
|
322 |
+
|
323 |
+
if move == 0:
|
324 |
+
return x
|
325 |
+
else:
|
326 |
+
left = x[:, :, :move]
|
327 |
+
right = x[:, :, move:]
|
328 |
+
shuffled = torch.cat([right, left], dim=2)
|
329 |
+
|
330 |
+
return shuffled
|
331 |
+
|
332 |
+
class MFCC(nn.Module):
|
333 |
+
def __init__(self, n_mfcc=40, n_mels=80):
|
334 |
+
super(MFCC, self).__init__()
|
335 |
+
self.n_mfcc = n_mfcc
|
336 |
+
self.n_mels = n_mels
|
337 |
+
self.norm = 'ortho'
|
338 |
+
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
|
339 |
+
self.register_buffer('dct_mat', dct_mat)
|
340 |
+
|
341 |
+
def forward(self, mel_specgram):
|
342 |
+
if len(mel_specgram.shape) == 2:
|
343 |
+
mel_specgram = mel_specgram.unsqueeze(0)
|
344 |
+
unsqueezed = True
|
345 |
+
else:
|
346 |
+
unsqueezed = False
|
347 |
+
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
|
348 |
+
# -> (channel, time, n_mfcc).tranpose(...)
|
349 |
+
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
|
350 |
+
|
351 |
+
# unpack batch
|
352 |
+
if unsqueezed:
|
353 |
+
mfcc = mfcc.squeeze(0)
|
354 |
+
return mfcc
|
modules/length_regulator.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from modules.commons import sequence_mask
|
5 |
+
|
6 |
+
|
7 |
+
class InterpolateRegulator(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
channels: int,
|
11 |
+
sampling_ratios: Tuple,
|
12 |
+
is_discrete: bool = False,
|
13 |
+
codebook_size: int = 1024, # for discrete only
|
14 |
+
out_channels: int = None,
|
15 |
+
groups: int = 1,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
self.sampling_ratios = sampling_ratios
|
19 |
+
out_channels = out_channels or channels
|
20 |
+
model = nn.ModuleList([])
|
21 |
+
if len(sampling_ratios) > 0:
|
22 |
+
for _ in sampling_ratios:
|
23 |
+
module = nn.Conv1d(channels, channels, 3, 1, 1)
|
24 |
+
norm = nn.GroupNorm(groups, channels)
|
25 |
+
act = nn.Mish()
|
26 |
+
model.extend([module, norm, act])
|
27 |
+
model.append(
|
28 |
+
nn.Conv1d(channels, out_channels, 1, 1)
|
29 |
+
)
|
30 |
+
self.model = nn.Sequential(*model)
|
31 |
+
self.embedding = nn.Embedding(codebook_size, channels)
|
32 |
+
self.is_discrete = is_discrete
|
33 |
+
|
34 |
+
def forward(self, x, ylens=None):
|
35 |
+
if self.is_discrete:
|
36 |
+
x = self.embedding(x)
|
37 |
+
# x in (B, T, D)
|
38 |
+
mask = sequence_mask(ylens).unsqueeze(-1)
|
39 |
+
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
|
40 |
+
out = self.model(x).transpose(1, 2).contiguous()
|
41 |
+
olens = ylens
|
42 |
+
return out * mask, olens
|
modules/wavenet.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from modules.encodec import SConv1d
|
7 |
+
|
8 |
+
from . import commons
|
9 |
+
LRELU_SLOPE = 0.1
|
10 |
+
|
11 |
+
class LayerNorm(nn.Module):
|
12 |
+
def __init__(self, channels, eps=1e-5):
|
13 |
+
super().__init__()
|
14 |
+
self.channels = channels
|
15 |
+
self.eps = eps
|
16 |
+
|
17 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
18 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
x = x.transpose(1, -1)
|
22 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
23 |
+
return x.transpose(1, -1)
|
24 |
+
|
25 |
+
|
26 |
+
class ConvReluNorm(nn.Module):
|
27 |
+
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
28 |
+
super().__init__()
|
29 |
+
self.in_channels = in_channels
|
30 |
+
self.hidden_channels = hidden_channels
|
31 |
+
self.out_channels = out_channels
|
32 |
+
self.kernel_size = kernel_size
|
33 |
+
self.n_layers = n_layers
|
34 |
+
self.p_dropout = p_dropout
|
35 |
+
assert n_layers > 1, "Number of layers should be larger than 0."
|
36 |
+
|
37 |
+
self.conv_layers = nn.ModuleList()
|
38 |
+
self.norm_layers = nn.ModuleList()
|
39 |
+
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
40 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
41 |
+
self.relu_drop = nn.Sequential(
|
42 |
+
nn.ReLU(),
|
43 |
+
nn.Dropout(p_dropout))
|
44 |
+
for _ in range(n_layers - 1):
|
45 |
+
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
46 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
47 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
48 |
+
self.proj.weight.data.zero_()
|
49 |
+
self.proj.bias.data.zero_()
|
50 |
+
|
51 |
+
def forward(self, x, x_mask):
|
52 |
+
x_org = x
|
53 |
+
for i in range(self.n_layers):
|
54 |
+
x = self.conv_layers[i](x * x_mask)
|
55 |
+
x = self.norm_layers[i](x)
|
56 |
+
x = self.relu_drop(x)
|
57 |
+
x = x_org + self.proj(x)
|
58 |
+
return x * x_mask
|
59 |
+
|
60 |
+
|
61 |
+
class DDSConv(nn.Module):
|
62 |
+
"""
|
63 |
+
Dialted and Depth-Separable Convolution
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
|
67 |
+
super().__init__()
|
68 |
+
self.channels = channels
|
69 |
+
self.kernel_size = kernel_size
|
70 |
+
self.n_layers = n_layers
|
71 |
+
self.p_dropout = p_dropout
|
72 |
+
|
73 |
+
self.drop = nn.Dropout(p_dropout)
|
74 |
+
self.convs_sep = nn.ModuleList()
|
75 |
+
self.convs_1x1 = nn.ModuleList()
|
76 |
+
self.norms_1 = nn.ModuleList()
|
77 |
+
self.norms_2 = nn.ModuleList()
|
78 |
+
for i in range(n_layers):
|
79 |
+
dilation = kernel_size ** i
|
80 |
+
padding = (kernel_size * dilation - dilation) // 2
|
81 |
+
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
|
82 |
+
groups=channels, dilation=dilation, padding=padding
|
83 |
+
))
|
84 |
+
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
85 |
+
self.norms_1.append(LayerNorm(channels))
|
86 |
+
self.norms_2.append(LayerNorm(channels))
|
87 |
+
|
88 |
+
def forward(self, x, x_mask, g=None):
|
89 |
+
if g is not None:
|
90 |
+
x = x + g
|
91 |
+
for i in range(self.n_layers):
|
92 |
+
y = self.convs_sep[i](x * x_mask)
|
93 |
+
y = self.norms_1[i](y)
|
94 |
+
y = F.gelu(y)
|
95 |
+
y = self.convs_1x1[i](y)
|
96 |
+
y = self.norms_2[i](y)
|
97 |
+
y = F.gelu(y)
|
98 |
+
y = self.drop(y)
|
99 |
+
x = x + y
|
100 |
+
return x * x_mask
|
101 |
+
|
102 |
+
|
103 |
+
class WN(torch.nn.Module):
|
104 |
+
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, causal=False):
|
105 |
+
super(WN, self).__init__()
|
106 |
+
conv1d_type = SConv1d
|
107 |
+
assert (kernel_size % 2 == 1)
|
108 |
+
self.hidden_channels = hidden_channels
|
109 |
+
self.kernel_size = kernel_size,
|
110 |
+
self.dilation_rate = dilation_rate
|
111 |
+
self.n_layers = n_layers
|
112 |
+
self.gin_channels = gin_channels
|
113 |
+
self.p_dropout = p_dropout
|
114 |
+
|
115 |
+
self.in_layers = torch.nn.ModuleList()
|
116 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
117 |
+
self.drop = nn.Dropout(p_dropout)
|
118 |
+
|
119 |
+
if gin_channels != 0:
|
120 |
+
self.cond_layer = conv1d_type(gin_channels, 2 * hidden_channels * n_layers, 1, norm='weight_norm')
|
121 |
+
|
122 |
+
for i in range(n_layers):
|
123 |
+
dilation = dilation_rate ** i
|
124 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
125 |
+
in_layer = conv1d_type(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation,
|
126 |
+
padding=padding, norm='weight_norm', causal=causal)
|
127 |
+
self.in_layers.append(in_layer)
|
128 |
+
|
129 |
+
# last one is not necessary
|
130 |
+
if i < n_layers - 1:
|
131 |
+
res_skip_channels = 2 * hidden_channels
|
132 |
+
else:
|
133 |
+
res_skip_channels = hidden_channels
|
134 |
+
|
135 |
+
res_skip_layer = conv1d_type(hidden_channels, res_skip_channels, 1, norm='weight_norm', causal=causal)
|
136 |
+
self.res_skip_layers.append(res_skip_layer)
|
137 |
+
|
138 |
+
def forward(self, x, x_mask, g=None, **kwargs):
|
139 |
+
output = torch.zeros_like(x)
|
140 |
+
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
141 |
+
|
142 |
+
if g is not None:
|
143 |
+
g = self.cond_layer(g)
|
144 |
+
|
145 |
+
for i in range(self.n_layers):
|
146 |
+
x_in = self.in_layers[i](x)
|
147 |
+
if g is not None:
|
148 |
+
cond_offset = i * 2 * self.hidden_channels
|
149 |
+
g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
|
150 |
+
else:
|
151 |
+
g_l = torch.zeros_like(x_in)
|
152 |
+
|
153 |
+
acts = commons.fused_add_tanh_sigmoid_multiply(
|
154 |
+
x_in,
|
155 |
+
g_l,
|
156 |
+
n_channels_tensor)
|
157 |
+
acts = self.drop(acts)
|
158 |
+
|
159 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
160 |
+
if i < self.n_layers - 1:
|
161 |
+
res_acts = res_skip_acts[:, :self.hidden_channels, :]
|
162 |
+
x = (x + res_acts) * x_mask
|
163 |
+
output = output + res_skip_acts[:, self.hidden_channels:, :]
|
164 |
+
else:
|
165 |
+
output = output + res_skip_acts
|
166 |
+
return output * x_mask
|
167 |
+
|
168 |
+
def remove_weight_norm(self):
|
169 |
+
if self.gin_channels != 0:
|
170 |
+
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
171 |
+
for l in self.in_layers:
|
172 |
+
torch.nn.utils.remove_weight_norm(l)
|
173 |
+
for l in self.res_skip_layers:
|
174 |
+
torch.nn.utils.remove_weight_norm(l)
|