zhzluke96
commited on
Commit
·
d2b7e94
1
Parent(s):
9d9fe0d
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env.webui +2 -2
- README.md +2 -2
- launch.py +11 -10
- modules/ChatTTS/ChatTTS/__init__.py +1 -1
- modules/ChatTTS/ChatTTS/core.py +7 -7
- modules/ChatTTS/ChatTTS/infer/api.py +1 -0
- modules/ChatTTS/ChatTTS/model/dvae.py +79 -48
- modules/ChatTTS/ChatTTS/model/gpt.py +167 -87
- modules/ChatTTS/ChatTTS/utils/infer_utils.py +1 -0
- modules/ChatTTS/ChatTTS/utils/io_utils.py +6 -6
- modules/Denoiser/AudioDenoiser.py +5 -3
- modules/Denoiser/AudioNosiseModel.py +2 -3
- modules/Enhancer/ResembleEnhance.py +6 -9
- modules/SentenceSplitter.py +1 -0
- modules/SynthesizeSegments.py +24 -18
- modules/api/Api.py +3 -5
- modules/api/api_setup.py +11 -13
- modules/api/impl/google_api.py +2 -6
- modules/api/impl/handler/AudioHandler.py +2 -1
- modules/api/impl/handler/SSMLHandler.py +3 -3
- modules/api/impl/handler/TTSHandler.py +2 -2
- modules/api/impl/model/enhancer_model.py +1 -0
- modules/api/impl/models_api.py +1 -1
- modules/api/impl/openai_api.py +6 -11
- modules/api/impl/ping_api.py +1 -2
- modules/api/impl/refiner_api.py +0 -3
- modules/api/impl/speaker_api.py +3 -2
- modules/api/impl/ssml_api.py +3 -8
- modules/api/impl/style_api.py +1 -1
- modules/api/impl/tts_api.py +3 -7
- modules/api/impl/xtts_v2_api.py +5 -7
- modules/api/utils.py +3 -7
- modules/api/worker.py +2 -1
- modules/config.py +2 -2
- modules/data.py +0 -1
- modules/denoise.py +3 -5
- modules/devices/devices.py +4 -3
- modules/devices/mac_devices.py +3 -2
- modules/ffmpeg_env.py +2 -1
- modules/finetune/train_speaker.py +8 -5
- modules/finetune/utils/dataset.py +6 -6
- modules/finetune/utils/logger.py +3 -4
- modules/generate_audio.py +7 -10
- modules/models.py +5 -5
- modules/normalization.py +5 -3
- modules/prompts/news_oral_prompt.txt +23 -4
- modules/refiner.py +1 -2
- modules/repos_static/resemble_enhance/common.py +3 -1
- modules/repos_static/resemble_enhance/data/dataset.py +21 -7
- modules/repos_static/resemble_enhance/data/distorter/base.py +1 -1
.env.webui
CHANGED
@@ -14,9 +14,9 @@ DEBUG_GENERATE=True
|
|
14 |
PRELOAD_MODELS=True
|
15 |
|
16 |
# Text-to-Speech (TTS) configuration
|
17 |
-
TTS_MAX_LEN=
|
18 |
SSML_MAX_LEN=3000
|
19 |
MAX_BATCH_SIZE=12
|
20 |
|
21 |
-
V_GIT_TAG="🤗hf(0.6.1
|
22 |
V_GIT_COMMIT=main
|
|
|
14 |
PRELOAD_MODELS=True
|
15 |
|
16 |
# Text-to-Speech (TTS) configuration
|
17 |
+
TTS_MAX_LEN=2000
|
18 |
SSML_MAX_LEN=3000
|
19 |
MAX_BATCH_SIZE=12
|
20 |
|
21 |
+
V_GIT_TAG="🤗hf(0.6.1)"
|
22 |
V_GIT_COMMIT=main
|
README.md
CHANGED
@@ -16,7 +16,7 @@ sdk_version: 4.36.1
|
|
16 |
|
17 |
| 类型 | 最大字符数 |
|
18 |
|------|-----------|
|
19 |
-
| TTS |
|
20 |
| SSML | 3000 字符(不计算 SSML 标签,只计算文本) |
|
21 |
|
22 |
# HuggingFace Space Limit
|
@@ -25,7 +25,7 @@ Due to the runtime limit for GPU usage on HuggingFace, extremely long tasks will
|
|
25 |
|
26 |
| Type | Maximum Characters |
|
27 |
|------|---------------------|
|
28 |
-
| TTS |
|
29 |
| SSML | 3000 characters (excluding SSML tags, only counting text) |
|
30 |
|
31 |
# 🗣️ ChatTTS-Forge
|
|
|
16 |
|
17 |
| 类型 | 最大字符数 |
|
18 |
|------|-----------|
|
19 |
+
| TTS | 2000 字符 |
|
20 |
| SSML | 3000 字符(不计算 SSML 标签,只计算文本) |
|
21 |
|
22 |
# HuggingFace Space Limit
|
|
|
25 |
|
26 |
| Type | Maximum Characters |
|
27 |
|------|---------------------|
|
28 |
+
| TTS | 2000 characters |
|
29 |
| SSML | 3000 characters (excluding SSML tags, only counting text) |
|
30 |
|
31 |
# 🗣️ ChatTTS-Forge
|
launch.py
CHANGED
@@ -1,23 +1,24 @@
|
|
1 |
-
import os
|
2 |
import logging
|
|
|
3 |
|
4 |
-
from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
|
5 |
from modules.ffmpeg_env import setup_ffmpeg_path
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
)
|
|
|
|
|
|
|
12 |
|
13 |
import argparse
|
|
|
14 |
import uvicorn
|
15 |
|
16 |
-
from modules import
|
17 |
from modules.utils import env
|
18 |
|
19 |
-
from fastapi import FastAPI
|
20 |
-
|
21 |
logger = logging.getLogger(__name__)
|
22 |
|
23 |
if __name__ == "__main__":
|
|
|
|
|
1 |
import logging
|
2 |
+
import os
|
3 |
|
|
|
4 |
from modules.ffmpeg_env import setup_ffmpeg_path
|
5 |
|
6 |
+
try:
|
7 |
+
setup_ffmpeg_path()
|
8 |
+
logging.basicConfig(
|
9 |
+
level=os.getenv("LOG_LEVEL", "INFO"),
|
10 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
11 |
+
)
|
12 |
+
except BaseException:
|
13 |
+
pass
|
14 |
|
15 |
import argparse
|
16 |
+
|
17 |
import uvicorn
|
18 |
|
19 |
+
from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
|
20 |
from modules.utils import env
|
21 |
|
|
|
|
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
if __name__ == "__main__":
|
modules/ChatTTS/ChatTTS/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
from .core import Chat
|
|
|
1 |
+
from .core import Chat
|
modules/ChatTTS/ChatTTS/core.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
-
import os
|
2 |
import logging
|
3 |
-
|
4 |
|
5 |
import torch
|
|
|
|
|
6 |
from vocos import Vocos
|
|
|
|
|
7 |
from .model.dvae import DVAE
|
8 |
from .model.gpt import GPT_warpper
|
9 |
from .utils.infer_utils import (
|
10 |
-
count_invalid_characters,
|
11 |
-
detect_language,
|
12 |
apply_character_map,
|
13 |
apply_half2full_map,
|
|
|
|
|
14 |
)
|
15 |
from .utils.io_utils import get_latest_modified_file
|
16 |
-
from .infer.api import refine_text, infer_code
|
17 |
-
|
18 |
-
from huggingface_hub import snapshot_download
|
19 |
|
20 |
logging.basicConfig(level=logging.INFO)
|
21 |
|
|
|
|
|
1 |
import logging
|
2 |
+
import os
|
3 |
|
4 |
import torch
|
5 |
+
from huggingface_hub import snapshot_download
|
6 |
+
from omegaconf import OmegaConf
|
7 |
from vocos import Vocos
|
8 |
+
|
9 |
+
from .infer.api import infer_code, refine_text
|
10 |
from .model.dvae import DVAE
|
11 |
from .model.gpt import GPT_warpper
|
12 |
from .utils.infer_utils import (
|
|
|
|
|
13 |
apply_character_map,
|
14 |
apply_half2full_map,
|
15 |
+
count_invalid_characters,
|
16 |
+
detect_language,
|
17 |
)
|
18 |
from .utils.io_utils import get_latest_modified_file
|
|
|
|
|
|
|
19 |
|
20 |
logging.basicConfig(level=logging.INFO)
|
21 |
|
modules/ChatTTS/ChatTTS/infer/api.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
|
|
|
4 |
from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
|
5 |
|
6 |
|
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
|
4 |
+
|
5 |
from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
|
6 |
|
7 |
|
modules/ChatTTS/ChatTTS/model/dvae.py
CHANGED
@@ -1,28 +1,36 @@
|
|
1 |
import math
|
2 |
-
from einops import rearrange
|
3 |
-
from vector_quantize_pytorch import GroupedResidualFSQ
|
4 |
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
|
|
|
|
|
|
8 |
|
9 |
class ConvNeXtBlock(nn.Module):
|
10 |
def __init__(
|
11 |
self,
|
12 |
dim: int,
|
13 |
intermediate_dim: int,
|
14 |
-
kernel,
|
|
|
15 |
layer_scale_init_value: float = 1e-6,
|
16 |
):
|
17 |
# ConvNeXt Block copied from Vocos.
|
18 |
super().__init__()
|
19 |
-
self.dwconv = nn.Conv1d(
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
25 |
-
self.pwconv1 = nn.Linear(
|
|
|
|
|
26 |
self.act = nn.GELU()
|
27 |
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
28 |
self.gamma = (
|
@@ -31,7 +39,7 @@ class ConvNeXtBlock(nn.Module):
|
|
31 |
else None
|
32 |
)
|
33 |
|
34 |
-
def forward(self, x: torch.Tensor, cond
|
35 |
residual = x
|
36 |
x = self.dwconv(x)
|
37 |
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
@@ -45,14 +53,11 @@ class ConvNeXtBlock(nn.Module):
|
|
45 |
|
46 |
x = residual + x
|
47 |
return x
|
48 |
-
|
49 |
|
50 |
|
51 |
class GFSQ(nn.Module):
|
52 |
|
53 |
-
def __init__(self,
|
54 |
-
dim, levels, G, R, eps=1e-5, transpose = True
|
55 |
-
):
|
56 |
super(GFSQ, self).__init__()
|
57 |
self.quantizer = GroupedResidualFSQ(
|
58 |
dim=dim,
|
@@ -65,50 +70,74 @@ class GFSQ(nn.Module):
|
|
65 |
self.transpose = transpose
|
66 |
self.G = G
|
67 |
self.R = R
|
68 |
-
|
69 |
def _embed(self, x):
|
70 |
if self.transpose:
|
71 |
-
x = x.transpose(1,2)
|
72 |
x = rearrange(
|
73 |
-
x,
|
74 |
-
|
|
|
|
|
|
|
75 |
feat = self.quantizer.get_output_from_indices(x)
|
76 |
-
return feat.transpose(1,2) if self.transpose else feat
|
77 |
-
|
78 |
-
def forward(
|
|
|
|
|
|
|
79 |
if self.transpose:
|
80 |
-
x = x.transpose(1,2)
|
81 |
feat, ind = self.quantizer(x)
|
82 |
ind = rearrange(
|
83 |
-
ind,
|
84 |
-
|
|
|
85 |
embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
|
86 |
-
e_mean = torch.mean(embed_onehot, dim=[0,1])
|
87 |
e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
|
88 |
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
|
89 |
-
|
90 |
return (
|
91 |
torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
|
92 |
-
feat.transpose(1,2) if self.transpose else feat,
|
93 |
perplexity,
|
94 |
None,
|
95 |
-
ind.transpose(1,2) if self.transpose else ind,
|
96 |
)
|
97 |
-
|
|
|
98 |
class DVAEDecoder(nn.Module):
|
99 |
-
def __init__(
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
super().__init__()
|
104 |
self.up = up
|
105 |
self.conv_in = nn.Sequential(
|
106 |
-
nn.Conv1d(idim, bn_dim, 3, 1, 1),
|
107 |
-
nn.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
)
|
109 |
-
self.decoder_block = nn.ModuleList([
|
110 |
-
ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,)
|
111 |
-
for _ in range(n_layer)])
|
112 |
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
|
113 |
|
114 |
def forward(self, input, conditioning=None):
|
@@ -117,17 +146,15 @@ class DVAEDecoder(nn.Module):
|
|
117 |
x = self.conv_in(x)
|
118 |
for f in self.decoder_block:
|
119 |
x = f(x, conditioning)
|
120 |
-
|
121 |
x = self.conv_out(x)
|
122 |
return x.transpose(1, 2)
|
123 |
-
|
124 |
|
125 |
class DVAE(nn.Module):
|
126 |
-
def __init__(
|
127 |
-
self, decoder_config, vq_config, dim=512
|
128 |
-
):
|
129 |
super().__init__()
|
130 |
-
self.register_buffer(
|
131 |
|
132 |
self.decoder = DVAEDecoder(**decoder_config)
|
133 |
self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
|
@@ -142,10 +169,14 @@ class DVAE(nn.Module):
|
|
142 |
vq_feats = self.vq_layer._embed(inp)
|
143 |
else:
|
144 |
vq_feats = inp.detach().clone()
|
145 |
-
|
146 |
-
vq_feats =
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
149 |
|
150 |
vq_feats = vq_feats.transpose(1, 2)
|
151 |
dec_out = self.decoder(input=vq_feats)
|
|
|
1 |
import math
|
|
|
|
|
2 |
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
from vector_quantize_pytorch import GroupedResidualFSQ
|
8 |
+
|
9 |
|
10 |
class ConvNeXtBlock(nn.Module):
|
11 |
def __init__(
|
12 |
self,
|
13 |
dim: int,
|
14 |
intermediate_dim: int,
|
15 |
+
kernel,
|
16 |
+
dilation,
|
17 |
layer_scale_init_value: float = 1e-6,
|
18 |
):
|
19 |
# ConvNeXt Block copied from Vocos.
|
20 |
super().__init__()
|
21 |
+
self.dwconv = nn.Conv1d(
|
22 |
+
dim,
|
23 |
+
dim,
|
24 |
+
kernel_size=kernel,
|
25 |
+
padding=dilation * (kernel // 2),
|
26 |
+
dilation=dilation,
|
27 |
+
groups=dim,
|
28 |
+
) # depthwise conv
|
29 |
+
|
30 |
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
31 |
+
self.pwconv1 = nn.Linear(
|
32 |
+
dim, intermediate_dim
|
33 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
34 |
self.act = nn.GELU()
|
35 |
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
36 |
self.gamma = (
|
|
|
39 |
else None
|
40 |
)
|
41 |
|
42 |
+
def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
|
43 |
residual = x
|
44 |
x = self.dwconv(x)
|
45 |
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
|
|
53 |
|
54 |
x = residual + x
|
55 |
return x
|
|
|
56 |
|
57 |
|
58 |
class GFSQ(nn.Module):
|
59 |
|
60 |
+
def __init__(self, dim, levels, G, R, eps=1e-5, transpose=True):
|
|
|
|
|
61 |
super(GFSQ, self).__init__()
|
62 |
self.quantizer = GroupedResidualFSQ(
|
63 |
dim=dim,
|
|
|
70 |
self.transpose = transpose
|
71 |
self.G = G
|
72 |
self.R = R
|
73 |
+
|
74 |
def _embed(self, x):
|
75 |
if self.transpose:
|
76 |
+
x = x.transpose(1, 2)
|
77 |
x = rearrange(
|
78 |
+
x,
|
79 |
+
"b t (g r) -> g b t r",
|
80 |
+
g=self.G,
|
81 |
+
r=self.R,
|
82 |
+
)
|
83 |
feat = self.quantizer.get_output_from_indices(x)
|
84 |
+
return feat.transpose(1, 2) if self.transpose else feat
|
85 |
+
|
86 |
+
def forward(
|
87 |
+
self,
|
88 |
+
x,
|
89 |
+
):
|
90 |
if self.transpose:
|
91 |
+
x = x.transpose(1, 2)
|
92 |
feat, ind = self.quantizer(x)
|
93 |
ind = rearrange(
|
94 |
+
ind,
|
95 |
+
"g b t r ->b t (g r)",
|
96 |
+
)
|
97 |
embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
|
98 |
+
e_mean = torch.mean(embed_onehot, dim=[0, 1])
|
99 |
e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
|
100 |
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
|
101 |
+
|
102 |
return (
|
103 |
torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
|
104 |
+
feat.transpose(1, 2) if self.transpose else feat,
|
105 |
perplexity,
|
106 |
None,
|
107 |
+
ind.transpose(1, 2) if self.transpose else ind,
|
108 |
)
|
109 |
+
|
110 |
+
|
111 |
class DVAEDecoder(nn.Module):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
idim,
|
115 |
+
odim,
|
116 |
+
n_layer=12,
|
117 |
+
bn_dim=64,
|
118 |
+
hidden=256,
|
119 |
+
kernel=7,
|
120 |
+
dilation=2,
|
121 |
+
up=False,
|
122 |
+
):
|
123 |
super().__init__()
|
124 |
self.up = up
|
125 |
self.conv_in = nn.Sequential(
|
126 |
+
nn.Conv1d(idim, bn_dim, 3, 1, 1),
|
127 |
+
nn.GELU(),
|
128 |
+
nn.Conv1d(bn_dim, hidden, 3, 1, 1),
|
129 |
+
)
|
130 |
+
self.decoder_block = nn.ModuleList(
|
131 |
+
[
|
132 |
+
ConvNeXtBlock(
|
133 |
+
hidden,
|
134 |
+
hidden * 4,
|
135 |
+
kernel,
|
136 |
+
dilation,
|
137 |
+
)
|
138 |
+
for _ in range(n_layer)
|
139 |
+
]
|
140 |
)
|
|
|
|
|
|
|
141 |
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
|
142 |
|
143 |
def forward(self, input, conditioning=None):
|
|
|
146 |
x = self.conv_in(x)
|
147 |
for f in self.decoder_block:
|
148 |
x = f(x, conditioning)
|
149 |
+
|
150 |
x = self.conv_out(x)
|
151 |
return x.transpose(1, 2)
|
152 |
+
|
153 |
|
154 |
class DVAE(nn.Module):
|
155 |
+
def __init__(self, decoder_config, vq_config, dim=512):
|
|
|
|
|
156 |
super().__init__()
|
157 |
+
self.register_buffer("coef", torch.randn(1, 100, 1))
|
158 |
|
159 |
self.decoder = DVAEDecoder(**decoder_config)
|
160 |
self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
|
|
|
169 |
vq_feats = self.vq_layer._embed(inp)
|
170 |
else:
|
171 |
vq_feats = inp.detach().clone()
|
172 |
+
|
173 |
+
vq_feats = (
|
174 |
+
vq_feats.view(
|
175 |
+
(vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
|
176 |
+
)
|
177 |
+
.permute(0, 2, 3, 1)
|
178 |
+
.flatten(2)
|
179 |
+
)
|
180 |
|
181 |
vq_feats = vq_feats.transpose(1, 2)
|
182 |
dec_out = self.decoder(input=vq_feats)
|
modules/ChatTTS/ChatTTS/model/gpt.py
CHANGED
@@ -1,19 +1,20 @@
|
|
1 |
import os
|
|
|
2 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
3 |
|
4 |
import logging
|
5 |
-
from tqdm import tqdm
|
6 |
-
from einops import rearrange
|
7 |
-
from transformers.cache_utils import Cache
|
8 |
|
9 |
import torch
|
10 |
import torch.nn as nn
|
11 |
import torch.nn.functional as F
|
12 |
import torch.nn.utils.parametrize as P
|
|
|
13 |
from torch.nn.utils.parametrizations import weight_norm
|
14 |
-
from
|
15 |
-
|
16 |
-
|
|
|
|
|
17 |
class LlamaMLP(nn.Module):
|
18 |
def __init__(self, hidden_size, intermediate_size):
|
19 |
super().__init__()
|
@@ -27,70 +28,106 @@ class LlamaMLP(nn.Module):
|
|
27 |
def forward(self, x):
|
28 |
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
29 |
return down_proj
|
30 |
-
|
31 |
-
|
32 |
class GPT_warpper(nn.Module):
|
33 |
def __init__(
|
34 |
-
self,
|
35 |
-
gpt_config,
|
36 |
num_audio_tokens,
|
37 |
num_text_tokens,
|
38 |
num_vq=4,
|
39 |
**kwargs,
|
40 |
-
|
41 |
super().__init__()
|
42 |
|
43 |
self.logger = logging.getLogger(__name__)
|
44 |
self.gpt = self.build_model(gpt_config)
|
45 |
-
self.model_dim = self.gpt.config.hidden_size
|
46 |
|
47 |
self.num_vq = num_vq
|
48 |
-
self.emb_code = nn.ModuleList(
|
|
|
|
|
49 |
self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
|
50 |
-
self.head_text = weight_norm(
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
def build_model(self, config):
|
54 |
-
|
55 |
configuration = LlamaConfig(**config)
|
56 |
model = LlamaModel(configuration)
|
57 |
del model.embed_tokens
|
58 |
-
|
59 |
return model
|
60 |
-
|
61 |
def get_emb(self, input_ids, text_mask, **kwargs):
|
62 |
|
63 |
emb_text = self.emb_text(input_ids[text_mask][:, 0])
|
64 |
-
|
65 |
-
emb_code = [
|
|
|
|
|
66 |
emb_code = torch.stack(emb_code, 2).sum(2)
|
67 |
-
|
68 |
-
emb = torch.zeros(
|
|
|
|
|
|
|
|
|
69 |
emb[text_mask] = emb_text
|
70 |
emb[~text_mask] = emb_code.to(emb.dtype)
|
71 |
-
|
72 |
return emb
|
73 |
-
|
74 |
def prepare_inputs_for_generation(
|
75 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
):
|
77 |
# With static cache, the `past_key_values` is None
|
78 |
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
79 |
has_static_cache = False
|
80 |
if past_key_values is None:
|
81 |
-
past_key_values = getattr(
|
|
|
|
|
82 |
has_static_cache = past_key_values is not None
|
83 |
|
84 |
past_length = 0
|
85 |
if past_key_values is not None:
|
86 |
if isinstance(past_key_values, Cache):
|
87 |
-
past_length =
|
|
|
|
|
|
|
|
|
88 |
max_cache_length = (
|
89 |
-
torch.tensor(
|
|
|
|
|
90 |
if past_key_values.get_max_length() is not None
|
91 |
else None
|
92 |
)
|
93 |
-
cache_length =
|
|
|
|
|
|
|
|
|
94 |
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
|
95 |
else:
|
96 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
@@ -100,7 +137,10 @@ class GPT_warpper(nn.Module):
|
|
100 |
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
101 |
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
102 |
# input)
|
103 |
-
if
|
|
|
|
|
|
|
104 |
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
105 |
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
106 |
# input_ids based on the past_length.
|
@@ -133,9 +173,13 @@ class GPT_warpper(nn.Module):
|
|
133 |
# TODO: use `next_tokens` directly instead.
|
134 |
model_inputs = {"input_ids": input_ids.contiguous()}
|
135 |
|
136 |
-
input_length =
|
|
|
|
|
137 |
if cache_position is None:
|
138 |
-
cache_position = torch.arange(
|
|
|
|
|
139 |
else:
|
140 |
cache_position = cache_position[-input_length:]
|
141 |
|
@@ -152,118 +196,154 @@ class GPT_warpper(nn.Module):
|
|
152 |
}
|
153 |
)
|
154 |
return model_inputs
|
155 |
-
|
156 |
def generate(
|
157 |
-
self,
|
158 |
-
emb,
|
159 |
-
inputs_ids,
|
160 |
-
temperature,
|
161 |
-
eos_token,
|
162 |
-
attention_mask
|
163 |
-
max_new_token
|
164 |
-
min_new_token
|
165 |
-
LogitsWarpers
|
166 |
-
LogitsProcessors
|
167 |
infer_text=False,
|
168 |
return_attn=False,
|
169 |
return_hidden=False,
|
170 |
-
disable_tqdm=False
|
171 |
):
|
172 |
if disable_tqdm:
|
173 |
tqdm = lambda x: x
|
174 |
else:
|
175 |
from tqdm import tqdm
|
176 |
-
|
177 |
-
with torch.no_grad():
|
178 |
-
|
179 |
attentions = []
|
180 |
hiddens = []
|
181 |
-
|
182 |
-
start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
|
|
|
|
|
183 |
finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
|
184 |
-
|
185 |
temperature = temperature[None].expand(inputs_ids.shape[0], -1)
|
186 |
temperature = rearrange(temperature, "b n -> (b n) 1")
|
187 |
|
188 |
-
attention_mask_cache = torch.ones(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
if attention_mask is not None:
|
190 |
-
attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
|
191 |
-
|
192 |
for i in tqdm(range(max_new_token)):
|
193 |
if finish.all():
|
194 |
continue
|
195 |
-
|
196 |
-
model_input = self.prepare_inputs_for_generation(
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
200 |
if i == 0:
|
201 |
-
model_input[
|
202 |
else:
|
203 |
if infer_text:
|
204 |
-
model_input[
|
|
|
|
|
205 |
else:
|
206 |
-
code_emb = [
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
210 |
outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
|
211 |
attentions.append(outputs.attentions)
|
212 |
-
hidden_states = outputs[0]
|
213 |
if return_hidden:
|
214 |
hiddens.append(hidden_states[:, -1])
|
215 |
|
216 |
with P.cached():
|
217 |
if infer_text:
|
218 |
-
logits = self.head_text(hidden_states)
|
219 |
else:
|
220 |
-
logits = torch.stack(
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
logits = logits[:, -1].float()
|
223 |
|
224 |
if not infer_text:
|
225 |
logits = rearrange(logits, "b c n -> (b n) c")
|
226 |
-
logits_token = rearrange(
|
|
|
|
|
227 |
else:
|
228 |
logits_token = inputs_ids[:, start_idx:, 0]
|
229 |
-
|
230 |
logits = logits / temperature
|
231 |
-
|
232 |
for logitsProcessors in LogitsProcessors:
|
233 |
logits = logitsProcessors(logits_token, logits)
|
234 |
-
|
235 |
for logitsWarpers in LogitsWarpers:
|
236 |
logits = logitsWarpers(logits_token, logits)
|
237 |
-
|
238 |
if i < min_new_token:
|
239 |
logits[:, eos_token] = -torch.inf
|
240 |
-
|
241 |
scores = F.softmax(logits, dim=-1)
|
242 |
-
|
243 |
idx_next = torch.multinomial(scores, num_samples=1)
|
244 |
-
|
245 |
if not infer_text:
|
246 |
idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
|
247 |
finish = finish | (idx_next == eos_token).any(1)
|
248 |
inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
|
249 |
else:
|
250 |
finish = finish | (idx_next == eos_token).any(1)
|
251 |
-
inputs_ids = torch.cat(
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
|
253 |
end_idx = end_idx + (~finish).int()
|
254 |
-
|
255 |
-
inputs_ids = [
|
|
|
|
|
|
|
256 |
inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
|
257 |
-
|
258 |
if return_hidden:
|
259 |
hiddens = torch.stack(hiddens, 1)
|
260 |
hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
|
261 |
-
|
262 |
if not finish.all():
|
263 |
-
self.logger.warn(
|
264 |
-
|
|
|
|
|
265 |
return {
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
}
|
|
|
1 |
import os
|
2 |
+
|
3 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
4 |
|
5 |
import logging
|
|
|
|
|
|
|
6 |
|
7 |
import torch
|
8 |
import torch.nn as nn
|
9 |
import torch.nn.functional as F
|
10 |
import torch.nn.utils.parametrize as P
|
11 |
+
from einops import rearrange
|
12 |
from torch.nn.utils.parametrizations import weight_norm
|
13 |
+
from tqdm import tqdm
|
14 |
+
from transformers import LlamaConfig, LlamaModel
|
15 |
+
from transformers.cache_utils import Cache
|
16 |
+
|
17 |
+
|
18 |
class LlamaMLP(nn.Module):
|
19 |
def __init__(self, hidden_size, intermediate_size):
|
20 |
super().__init__()
|
|
|
28 |
def forward(self, x):
|
29 |
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
30 |
return down_proj
|
31 |
+
|
32 |
+
|
33 |
class GPT_warpper(nn.Module):
|
34 |
def __init__(
|
35 |
+
self,
|
36 |
+
gpt_config,
|
37 |
num_audio_tokens,
|
38 |
num_text_tokens,
|
39 |
num_vq=4,
|
40 |
**kwargs,
|
41 |
+
):
|
42 |
super().__init__()
|
43 |
|
44 |
self.logger = logging.getLogger(__name__)
|
45 |
self.gpt = self.build_model(gpt_config)
|
46 |
+
self.model_dim = self.gpt.config.hidden_size
|
47 |
|
48 |
self.num_vq = num_vq
|
49 |
+
self.emb_code = nn.ModuleList(
|
50 |
+
[nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)]
|
51 |
+
)
|
52 |
self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
|
53 |
+
self.head_text = weight_norm(
|
54 |
+
nn.Linear(self.model_dim, num_text_tokens, bias=False), name="weight"
|
55 |
+
)
|
56 |
+
self.head_code = nn.ModuleList(
|
57 |
+
[
|
58 |
+
weight_norm(
|
59 |
+
nn.Linear(self.model_dim, num_audio_tokens, bias=False),
|
60 |
+
name="weight",
|
61 |
+
)
|
62 |
+
for i in range(self.num_vq)
|
63 |
+
]
|
64 |
+
)
|
65 |
|
66 |
def build_model(self, config):
|
67 |
+
|
68 |
configuration = LlamaConfig(**config)
|
69 |
model = LlamaModel(configuration)
|
70 |
del model.embed_tokens
|
71 |
+
|
72 |
return model
|
73 |
+
|
74 |
def get_emb(self, input_ids, text_mask, **kwargs):
|
75 |
|
76 |
emb_text = self.emb_text(input_ids[text_mask][:, 0])
|
77 |
+
|
78 |
+
emb_code = [
|
79 |
+
self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)
|
80 |
+
]
|
81 |
emb_code = torch.stack(emb_code, 2).sum(2)
|
82 |
+
|
83 |
+
emb = torch.zeros(
|
84 |
+
(input_ids.shape[:-1]) + (emb_text.shape[-1],),
|
85 |
+
device=emb_text.device,
|
86 |
+
dtype=emb_text.dtype,
|
87 |
+
)
|
88 |
emb[text_mask] = emb_text
|
89 |
emb[~text_mask] = emb_code.to(emb.dtype)
|
90 |
+
|
91 |
return emb
|
92 |
+
|
93 |
def prepare_inputs_for_generation(
|
94 |
+
self,
|
95 |
+
input_ids,
|
96 |
+
past_key_values=None,
|
97 |
+
attention_mask=None,
|
98 |
+
inputs_embeds=None,
|
99 |
+
cache_position=None,
|
100 |
+
**kwargs,
|
101 |
):
|
102 |
# With static cache, the `past_key_values` is None
|
103 |
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
104 |
has_static_cache = False
|
105 |
if past_key_values is None:
|
106 |
+
past_key_values = getattr(
|
107 |
+
self.gpt.layers[0].self_attn, "past_key_value", None
|
108 |
+
)
|
109 |
has_static_cache = past_key_values is not None
|
110 |
|
111 |
past_length = 0
|
112 |
if past_key_values is not None:
|
113 |
if isinstance(past_key_values, Cache):
|
114 |
+
past_length = (
|
115 |
+
cache_position[0]
|
116 |
+
if cache_position is not None
|
117 |
+
else past_key_values.get_seq_length()
|
118 |
+
)
|
119 |
max_cache_length = (
|
120 |
+
torch.tensor(
|
121 |
+
past_key_values.get_max_length(), device=input_ids.device
|
122 |
+
)
|
123 |
if past_key_values.get_max_length() is not None
|
124 |
else None
|
125 |
)
|
126 |
+
cache_length = (
|
127 |
+
past_length
|
128 |
+
if max_cache_length is None
|
129 |
+
else torch.min(max_cache_length, past_length)
|
130 |
+
)
|
131 |
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
|
132 |
else:
|
133 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
|
|
137 |
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
138 |
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
139 |
# input)
|
140 |
+
if (
|
141 |
+
attention_mask is not None
|
142 |
+
and attention_mask.shape[1] > input_ids.shape[1]
|
143 |
+
):
|
144 |
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
145 |
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
146 |
# input_ids based on the past_length.
|
|
|
173 |
# TODO: use `next_tokens` directly instead.
|
174 |
model_inputs = {"input_ids": input_ids.contiguous()}
|
175 |
|
176 |
+
input_length = (
|
177 |
+
position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
178 |
+
)
|
179 |
if cache_position is None:
|
180 |
+
cache_position = torch.arange(
|
181 |
+
past_length, past_length + input_length, device=input_ids.device
|
182 |
+
)
|
183 |
else:
|
184 |
cache_position = cache_position[-input_length:]
|
185 |
|
|
|
196 |
}
|
197 |
)
|
198 |
return model_inputs
|
199 |
+
|
200 |
def generate(
|
201 |
+
self,
|
202 |
+
emb,
|
203 |
+
inputs_ids,
|
204 |
+
temperature,
|
205 |
+
eos_token,
|
206 |
+
attention_mask=None,
|
207 |
+
max_new_token=2048,
|
208 |
+
min_new_token=0,
|
209 |
+
LogitsWarpers=[],
|
210 |
+
LogitsProcessors=[],
|
211 |
infer_text=False,
|
212 |
return_attn=False,
|
213 |
return_hidden=False,
|
214 |
+
disable_tqdm=False,
|
215 |
):
|
216 |
if disable_tqdm:
|
217 |
tqdm = lambda x: x
|
218 |
else:
|
219 |
from tqdm import tqdm
|
220 |
+
|
221 |
+
with torch.no_grad():
|
222 |
+
|
223 |
attentions = []
|
224 |
hiddens = []
|
225 |
+
|
226 |
+
start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
|
227 |
+
inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
|
228 |
+
)
|
229 |
finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
|
230 |
+
|
231 |
temperature = temperature[None].expand(inputs_ids.shape[0], -1)
|
232 |
temperature = rearrange(temperature, "b n -> (b n) 1")
|
233 |
|
234 |
+
attention_mask_cache = torch.ones(
|
235 |
+
(
|
236 |
+
inputs_ids.shape[0],
|
237 |
+
inputs_ids.shape[1] + max_new_token,
|
238 |
+
),
|
239 |
+
dtype=torch.bool,
|
240 |
+
device=inputs_ids.device,
|
241 |
+
)
|
242 |
if attention_mask is not None:
|
243 |
+
attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask
|
244 |
+
|
245 |
for i in tqdm(range(max_new_token)):
|
246 |
if finish.all():
|
247 |
continue
|
248 |
+
|
249 |
+
model_input = self.prepare_inputs_for_generation(
|
250 |
+
inputs_ids,
|
251 |
+
outputs.past_key_values if i != 0 else None,
|
252 |
+
attention_mask_cache[:, : inputs_ids.shape[1]],
|
253 |
+
use_cache=True,
|
254 |
+
)
|
255 |
+
|
256 |
if i == 0:
|
257 |
+
model_input["inputs_embeds"] = emb
|
258 |
else:
|
259 |
if infer_text:
|
260 |
+
model_input["inputs_embeds"] = self.emb_text(
|
261 |
+
model_input["input_ids"][:, :, 0]
|
262 |
+
)
|
263 |
else:
|
264 |
+
code_emb = [
|
265 |
+
self.emb_code[i](model_input["input_ids"][:, :, i])
|
266 |
+
for i in range(self.num_vq)
|
267 |
+
]
|
268 |
+
model_input["inputs_embeds"] = torch.stack(code_emb, 3).sum(3)
|
269 |
+
|
270 |
+
model_input["input_ids"] = None
|
271 |
outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
|
272 |
attentions.append(outputs.attentions)
|
273 |
+
hidden_states = outputs[0] # 🐻
|
274 |
if return_hidden:
|
275 |
hiddens.append(hidden_states[:, -1])
|
276 |
|
277 |
with P.cached():
|
278 |
if infer_text:
|
279 |
+
logits = self.head_text(hidden_states)
|
280 |
else:
|
281 |
+
logits = torch.stack(
|
282 |
+
[
|
283 |
+
self.head_code[i](hidden_states)
|
284 |
+
for i in range(self.num_vq)
|
285 |
+
],
|
286 |
+
3,
|
287 |
+
)
|
288 |
+
|
289 |
logits = logits[:, -1].float()
|
290 |
|
291 |
if not infer_text:
|
292 |
logits = rearrange(logits, "b c n -> (b n) c")
|
293 |
+
logits_token = rearrange(
|
294 |
+
inputs_ids[:, start_idx:], "b c n -> (b n) c"
|
295 |
+
)
|
296 |
else:
|
297 |
logits_token = inputs_ids[:, start_idx:, 0]
|
298 |
+
|
299 |
logits = logits / temperature
|
300 |
+
|
301 |
for logitsProcessors in LogitsProcessors:
|
302 |
logits = logitsProcessors(logits_token, logits)
|
303 |
+
|
304 |
for logitsWarpers in LogitsWarpers:
|
305 |
logits = logitsWarpers(logits_token, logits)
|
306 |
+
|
307 |
if i < min_new_token:
|
308 |
logits[:, eos_token] = -torch.inf
|
309 |
+
|
310 |
scores = F.softmax(logits, dim=-1)
|
311 |
+
|
312 |
idx_next = torch.multinomial(scores, num_samples=1)
|
313 |
+
|
314 |
if not infer_text:
|
315 |
idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
|
316 |
finish = finish | (idx_next == eos_token).any(1)
|
317 |
inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
|
318 |
else:
|
319 |
finish = finish | (idx_next == eos_token).any(1)
|
320 |
+
inputs_ids = torch.cat(
|
321 |
+
[
|
322 |
+
inputs_ids,
|
323 |
+
idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq),
|
324 |
+
],
|
325 |
+
1,
|
326 |
+
)
|
327 |
|
328 |
end_idx = end_idx + (~finish).int()
|
329 |
+
|
330 |
+
inputs_ids = [
|
331 |
+
inputs_ids[idx, start_idx : start_idx + i]
|
332 |
+
for idx, i in enumerate(end_idx.int())
|
333 |
+
]
|
334 |
inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
|
335 |
+
|
336 |
if return_hidden:
|
337 |
hiddens = torch.stack(hiddens, 1)
|
338 |
hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
|
339 |
+
|
340 |
if not finish.all():
|
341 |
+
self.logger.warn(
|
342 |
+
f"Incomplete result. hit max_new_token: {max_new_token}"
|
343 |
+
)
|
344 |
+
|
345 |
return {
|
346 |
+
"ids": inputs_ids,
|
347 |
+
"attentions": attentions,
|
348 |
+
"hiddens": hiddens,
|
349 |
+
}
|
modules/ChatTTS/ChatTTS/utils/infer_utils.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import re
|
|
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
|
|
|
1 |
import re
|
2 |
+
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
5 |
|
modules/ChatTTS/ChatTTS/utils/io_utils.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
-
|
2 |
-
import os
|
3 |
import logging
|
|
|
|
|
4 |
|
5 |
def get_latest_modified_file(directory):
|
6 |
logger = logging.getLogger(__name__)
|
7 |
-
|
8 |
-
files = [os.path.join(directory, f) for f in os.listdir(directory)]
|
9 |
if not files:
|
10 |
-
logger.log(logging.WARNING, f
|
11 |
return None
|
12 |
latest_file = max(files, key=os.path.getmtime)
|
13 |
|
14 |
-
return latest_file
|
|
|
|
|
|
|
1 |
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
|
5 |
def get_latest_modified_file(directory):
|
6 |
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
files = [os.path.join(directory, f) for f in os.listdir(directory)]
|
9 |
if not files:
|
10 |
+
logger.log(logging.WARNING, f"No files found in the directory: {directory}")
|
11 |
return None
|
12 |
latest_file = max(files, key=os.path.getmtime)
|
13 |
|
14 |
+
return latest_file
|
modules/Denoiser/AudioDenoiser.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1 |
import logging
|
2 |
import math
|
3 |
from typing import Union
|
|
|
4 |
import torch
|
5 |
import torchaudio
|
6 |
-
from torch import nn
|
7 |
-
from audio_denoiser.helpers.torch_helper import batched_apply
|
8 |
-
from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
|
9 |
from audio_denoiser.helpers.audio_helper import (
|
10 |
create_spectrogram,
|
11 |
reconstruct_from_spectrogram,
|
12 |
)
|
|
|
|
|
|
|
|
|
13 |
|
14 |
_expected_t_std = 0.23
|
15 |
_recommended_backend = "soundfile"
|
|
|
1 |
import logging
|
2 |
import math
|
3 |
from typing import Union
|
4 |
+
|
5 |
import torch
|
6 |
import torchaudio
|
|
|
|
|
|
|
7 |
from audio_denoiser.helpers.audio_helper import (
|
8 |
create_spectrogram,
|
9 |
reconstruct_from_spectrogram,
|
10 |
)
|
11 |
+
from audio_denoiser.helpers.torch_helper import batched_apply
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
|
15 |
|
16 |
_expected_t_std = 0.23
|
17 |
_recommended_backend = "soundfile"
|
modules/Denoiser/AudioNosiseModel.py
CHANGED
@@ -1,12 +1,11 @@
|
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
-
|
4 |
from audio_denoiser.modules.Permute import Permute
|
5 |
from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
|
6 |
from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler
|
7 |
|
8 |
-
import json
|
9 |
-
|
10 |
|
11 |
class AudioNoiseModel(nn.Module):
|
12 |
def __init__(self, config: dict):
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
import torch
|
4 |
import torch.nn as nn
|
|
|
5 |
from audio_denoiser.modules.Permute import Permute
|
6 |
from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
|
7 |
from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler
|
8 |
|
|
|
|
|
9 |
|
10 |
class AudioNoiseModel(nn.Module):
|
11 |
def __init__(self, config: dict):
|
modules/Enhancer/ResembleEnhance.py
CHANGED
@@ -1,20 +1,17 @@
|
|
1 |
import gc
|
|
|
|
|
|
|
2 |
from typing import Literal
|
3 |
|
4 |
import numpy as np
|
|
|
|
|
5 |
from modules.devices import devices
|
6 |
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
|
7 |
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
|
8 |
from modules.repos_static.resemble_enhance.inference import inference
|
9 |
-
|
10 |
-
import torch
|
11 |
-
|
12 |
from modules.utils.constants import MODELS_DIR
|
13 |
-
from pathlib import Path
|
14 |
-
|
15 |
-
from threading import Lock
|
16 |
-
|
17 |
-
import logging
|
18 |
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
@@ -155,8 +152,8 @@ def apply_audio_enhance(
|
|
155 |
|
156 |
|
157 |
if __name__ == "__main__":
|
158 |
-
import torchaudio
|
159 |
import gradio as gr
|
|
|
160 |
|
161 |
device = torch.device("cuda")
|
162 |
|
|
|
1 |
import gc
|
2 |
+
import logging
|
3 |
+
from pathlib import Path
|
4 |
+
from threading import Lock
|
5 |
from typing import Literal
|
6 |
|
7 |
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
from modules.devices import devices
|
11 |
from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
|
12 |
from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
|
13 |
from modules.repos_static.resemble_enhance.inference import inference
|
|
|
|
|
|
|
14 |
from modules.utils.constants import MODELS_DIR
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
|
|
152 |
|
153 |
|
154 |
if __name__ == "__main__":
|
|
|
155 |
import gradio as gr
|
156 |
+
import torchaudio
|
157 |
|
158 |
device = torch.device("cuda")
|
159 |
|
modules/SentenceSplitter.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import re
|
|
|
2 |
import zhon
|
3 |
|
4 |
|
|
|
1 |
import re
|
2 |
+
|
3 |
import zhon
|
4 |
|
5 |
|
modules/SynthesizeSegments.py
CHANGED
@@ -1,31 +1,37 @@
|
|
1 |
import copy
|
|
|
|
|
2 |
import re
|
|
|
|
|
|
|
3 |
from box import Box
|
4 |
from pydub import AudioSegment
|
5 |
-
|
6 |
-
from scipy.io.wavfile import write
|
7 |
-
import io
|
8 |
-
from modules.SentenceSplitter import SentenceSplitter
|
9 |
-
from modules.api.utils import calc_spk_style
|
10 |
-
from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
|
11 |
-
from modules.utils import rng
|
12 |
-
from modules.utils.audio import time_stretch, pitch_shift
|
13 |
from modules import generate_audio
|
|
|
14 |
from modules.normalization import text_normalize
|
15 |
-
import
|
16 |
-
import
|
17 |
-
|
18 |
-
from modules.
|
|
|
19 |
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
22 |
|
23 |
-
def audio_data_to_segment(audio_data, sr):
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
|
|
|
1 |
import copy
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
import re
|
5 |
+
from typing import List, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
from box import Box
|
9 |
from pydub import AudioSegment
|
10 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
from modules import generate_audio
|
12 |
+
from modules.api.utils import calc_spk_style
|
13 |
from modules.normalization import text_normalize
|
14 |
+
from modules.SentenceSplitter import SentenceSplitter
|
15 |
+
from modules.speaker import Speaker
|
16 |
+
from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
|
17 |
+
from modules.utils import rng
|
18 |
+
from modules.utils.audio import pitch_shift, time_stretch
|
19 |
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
22 |
|
23 |
+
def audio_data_to_segment(audio_data: np.ndarray, sr: int):
|
24 |
+
"""
|
25 |
+
optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
|
26 |
+
"""
|
27 |
+
audio_data = (audio_data * 32767).astype(np.int16)
|
28 |
+
audio_segment = AudioSegment(
|
29 |
+
audio_data.tobytes(),
|
30 |
+
frame_rate=sr,
|
31 |
+
sample_width=audio_data.dtype.itemsize,
|
32 |
+
channels=1,
|
33 |
+
)
|
34 |
+
return audio_segment
|
35 |
|
36 |
|
37 |
def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
|
modules/api/Api.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
-
|
2 |
-
from fastapi.middleware.cors import CORSMiddleware
|
3 |
-
|
4 |
import logging
|
5 |
|
|
|
|
|
6 |
from fastapi.staticfiles import StaticFiles
|
7 |
|
8 |
-
import fnmatch
|
9 |
-
|
10 |
|
11 |
def is_excluded(path, exclude_patterns):
|
12 |
"""
|
|
|
1 |
+
import fnmatch
|
|
|
|
|
2 |
import logging
|
3 |
|
4 |
+
from fastapi import FastAPI
|
5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from fastapi.staticfiles import StaticFiles
|
7 |
|
|
|
|
|
8 |
|
9 |
def is_excluded(path, exclude_patterns):
|
10 |
"""
|
modules/api/api_setup.py
CHANGED
@@ -1,26 +1,24 @@
|
|
1 |
-
import logging
|
2 |
-
from modules.Enhancer.ResembleEnhance import load_enhancer
|
3 |
-
from modules.devices import devices
|
4 |
import argparse
|
|
|
5 |
|
6 |
-
from modules import config
|
7 |
-
from modules.models import load_chat_tts
|
8 |
-
from modules.utils import env
|
9 |
-
from modules import generate_audio
|
10 |
from modules.api.Api import APIManager
|
11 |
-
|
12 |
from modules.api.impl import (
|
13 |
-
style_api,
|
14 |
-
tts_api,
|
15 |
-
ssml_api,
|
16 |
google_api,
|
|
|
17 |
openai_api,
|
|
|
18 |
refiner_api,
|
19 |
speaker_api,
|
20 |
-
|
21 |
-
|
|
|
22 |
xtts_v2_api,
|
23 |
)
|
|
|
|
|
|
|
|
|
24 |
|
25 |
logger = logging.getLogger(__name__)
|
26 |
|
|
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
+
import logging
|
3 |
|
4 |
+
from modules import config, generate_audio
|
|
|
|
|
|
|
5 |
from modules.api.Api import APIManager
|
|
|
6 |
from modules.api.impl import (
|
|
|
|
|
|
|
7 |
google_api,
|
8 |
+
models_api,
|
9 |
openai_api,
|
10 |
+
ping_api,
|
11 |
refiner_api,
|
12 |
speaker_api,
|
13 |
+
ssml_api,
|
14 |
+
style_api,
|
15 |
+
tts_api,
|
16 |
xtts_v2_api,
|
17 |
)
|
18 |
+
from modules.devices import devices
|
19 |
+
from modules.Enhancer.ResembleEnhance import load_enhancer
|
20 |
+
from modules.models import load_chat_tts
|
21 |
+
from modules.utils import env
|
22 |
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
modules/api/impl/google_api.py
CHANGED
@@ -1,22 +1,18 @@
|
|
1 |
from typing import Union
|
2 |
-
from fastapi import HTTPException
|
3 |
|
|
|
4 |
from pydantic import BaseModel
|
5 |
|
6 |
-
|
7 |
from modules.api.Api import APIManager
|
8 |
from modules.api.impl.handler.SSMLHandler import SSMLHandler
|
9 |
from modules.api.impl.handler.TTSHandler import TTSHandler
|
10 |
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
11 |
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
12 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
13 |
-
|
14 |
from modules.speaker import Speaker, speaker_mgr
|
15 |
|
16 |
|
17 |
-
from modules.api import utils as api_utils
|
18 |
-
|
19 |
-
|
20 |
class SynthesisInput(BaseModel):
|
21 |
text: Union[str, None] = None
|
22 |
ssml: Union[str, None] = None
|
|
|
1 |
from typing import Union
|
|
|
2 |
|
3 |
+
from fastapi import HTTPException
|
4 |
from pydantic import BaseModel
|
5 |
|
6 |
+
from modules.api import utils as api_utils
|
7 |
from modules.api.Api import APIManager
|
8 |
from modules.api.impl.handler.SSMLHandler import SSMLHandler
|
9 |
from modules.api.impl.handler.TTSHandler import TTSHandler
|
10 |
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
11 |
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
12 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
|
|
13 |
from modules.speaker import Speaker, speaker_mgr
|
14 |
|
15 |
|
|
|
|
|
|
|
16 |
class SynthesisInput(BaseModel):
|
17 |
text: Union[str, None] = None
|
18 |
ssml: Union[str, None] = None
|
modules/api/impl/handler/AudioHandler.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import base64
|
2 |
import io
|
|
|
3 |
import numpy as np
|
4 |
import soundfile as sf
|
5 |
|
6 |
-
from modules.api.impl.model.audio_model import AudioFormat
|
7 |
from modules.api import utils as api_utils
|
|
|
8 |
|
9 |
|
10 |
class AudioHandler:
|
|
|
1 |
import base64
|
2 |
import io
|
3 |
+
|
4 |
import numpy as np
|
5 |
import soundfile as sf
|
6 |
|
|
|
7 |
from modules.api import utils as api_utils
|
8 |
+
from modules.api.impl.model.audio_model import AudioFormat
|
9 |
|
10 |
|
11 |
class AudioHandler:
|
modules/api/impl/handler/SSMLHandler.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
-
from fastapi import HTTPException
|
2 |
import numpy as np
|
|
|
3 |
|
4 |
-
from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
|
5 |
-
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
6 |
from modules.api.impl.handler.AudioHandler import AudioHandler
|
7 |
from modules.api.impl.model.audio_model import AdjustConfig
|
8 |
from modules.api.impl.model.chattts_model import InferConfig
|
9 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
|
|
10 |
from modules.normalization import text_normalize
|
11 |
from modules.ssml_parser.SSMLParser import create_ssml_parser
|
|
|
12 |
from modules.utils import audio
|
13 |
|
14 |
|
|
|
|
|
1 |
import numpy as np
|
2 |
+
from fastapi import HTTPException
|
3 |
|
|
|
|
|
4 |
from modules.api.impl.handler.AudioHandler import AudioHandler
|
5 |
from modules.api.impl.model.audio_model import AdjustConfig
|
6 |
from modules.api.impl.model.chattts_model import InferConfig
|
7 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
8 |
+
from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
|
9 |
from modules.normalization import text_normalize
|
10 |
from modules.ssml_parser.SSMLParser import create_ssml_parser
|
11 |
+
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
12 |
from modules.utils import audio
|
13 |
|
14 |
|
modules/api/impl/handler/TTSHandler.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
import numpy as np
|
2 |
-
|
3 |
from modules.api.impl.handler.AudioHandler import AudioHandler
|
4 |
from modules.api.impl.model.audio_model import AdjustConfig
|
5 |
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
6 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
|
|
7 |
from modules.normalization import text_normalize
|
8 |
from modules.speaker import Speaker
|
9 |
from modules.synthesize_audio import synthesize_audio
|
10 |
-
|
11 |
from modules.utils.audio import apply_prosody_to_audio_data
|
12 |
|
13 |
|
|
|
1 |
import numpy as np
|
2 |
+
|
3 |
from modules.api.impl.handler.AudioHandler import AudioHandler
|
4 |
from modules.api.impl.model.audio_model import AdjustConfig
|
5 |
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
6 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
7 |
+
from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
|
8 |
from modules.normalization import text_normalize
|
9 |
from modules.speaker import Speaker
|
10 |
from modules.synthesize_audio import synthesize_audio
|
|
|
11 |
from modules.utils.audio import apply_prosody_to_audio_data
|
12 |
|
13 |
|
modules/api/impl/model/enhancer_model.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from typing import Literal
|
|
|
2 |
from pydantic import BaseModel
|
3 |
|
4 |
|
|
|
1 |
from typing import Literal
|
2 |
+
|
3 |
from pydantic import BaseModel
|
4 |
|
5 |
|
modules/api/impl/models_api.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer
|
2 |
from modules.api import utils as api_utils
|
3 |
from modules.api.Api import APIManager
|
|
|
4 |
from modules.models import reload_chat_tts, unload_chat_tts
|
5 |
|
6 |
|
|
|
|
|
1 |
from modules.api import utils as api_utils
|
2 |
from modules.api.Api import APIManager
|
3 |
+
from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer
|
4 |
from modules.models import reload_chat_tts, unload_chat_tts
|
5 |
|
6 |
|
modules/api/impl/openai_api.py
CHANGED
@@ -1,23 +1,18 @@
|
|
1 |
-
from
|
2 |
|
|
|
|
|
3 |
from numpy import clip
|
4 |
from pydantic import BaseModel, Field
|
5 |
-
from fastapi.responses import StreamingResponse
|
6 |
-
|
7 |
|
|
|
|
|
8 |
from modules.api.impl.handler.TTSHandler import TTSHandler
|
9 |
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
10 |
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
11 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
12 |
-
|
13 |
-
|
14 |
-
from typing import List, Optional
|
15 |
-
|
16 |
-
from modules.api import utils as api_utils
|
17 |
-
from modules.api.Api import APIManager
|
18 |
-
|
19 |
-
from modules.speaker import Speaker, speaker_mgr
|
20 |
from modules.data import styles_mgr
|
|
|
21 |
|
22 |
|
23 |
class AudioSpeechRequest(BaseModel):
|
|
|
1 |
+
from typing import List, Optional
|
2 |
|
3 |
+
from fastapi import Body, File, Form, HTTPException, UploadFile
|
4 |
+
from fastapi.responses import StreamingResponse
|
5 |
from numpy import clip
|
6 |
from pydantic import BaseModel, Field
|
|
|
|
|
7 |
|
8 |
+
from modules.api import utils as api_utils
|
9 |
+
from modules.api.Api import APIManager
|
10 |
from modules.api.impl.handler.TTSHandler import TTSHandler
|
11 |
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
12 |
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
13 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from modules.data import styles_mgr
|
15 |
+
from modules.speaker import Speaker, speaker_mgr
|
16 |
|
17 |
|
18 |
class AudioSpeechRequest(BaseModel):
|
modules/api/impl/ping_api.py
CHANGED
@@ -1,8 +1,7 @@
|
|
|
|
1 |
from modules.api import utils as api_utils
|
2 |
from modules.api.Api import APIManager
|
3 |
|
4 |
-
from modules import config
|
5 |
-
|
6 |
|
7 |
def setup(app: APIManager):
|
8 |
@app.get("/v1/ping", response_model=api_utils.BaseResponse)
|
|
|
1 |
+
from modules import config
|
2 |
from modules.api import utils as api_utils
|
3 |
from modules.api.Api import APIManager
|
4 |
|
|
|
|
|
5 |
|
6 |
def setup(app: APIManager):
|
7 |
@app.get("/v1/ping", response_model=api_utils.BaseResponse)
|
modules/api/impl/refiner_api.py
CHANGED
@@ -1,10 +1,7 @@
|
|
1 |
from fastapi import HTTPException
|
2 |
-
|
3 |
from pydantic import BaseModel
|
4 |
|
5 |
-
|
6 |
from modules import refiner
|
7 |
-
|
8 |
from modules.api import utils as api_utils
|
9 |
from modules.api.Api import APIManager
|
10 |
from modules.normalization import text_normalize
|
|
|
1 |
from fastapi import HTTPException
|
|
|
2 |
from pydantic import BaseModel
|
3 |
|
|
|
4 |
from modules import refiner
|
|
|
5 |
from modules.api import utils as api_utils
|
6 |
from modules.api.Api import APIManager
|
7 |
from modules.normalization import text_normalize
|
modules/api/impl/speaker_api.py
CHANGED
@@ -1,9 +1,10 @@
|
|
|
|
1 |
from fastapi import HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
-
|
4 |
-
from modules.speaker import speaker_mgr
|
5 |
from modules.api import utils as api_utils
|
6 |
from modules.api.Api import APIManager
|
|
|
7 |
|
8 |
|
9 |
class CreateSpeaker(BaseModel):
|
|
|
1 |
+
import torch
|
2 |
from fastapi import HTTPException
|
3 |
from pydantic import BaseModel
|
4 |
+
|
|
|
5 |
from modules.api import utils as api_utils
|
6 |
from modules.api.Api import APIManager
|
7 |
+
from modules.speaker import speaker_mgr
|
8 |
|
9 |
|
10 |
class CreateSpeaker(BaseModel):
|
modules/api/impl/ssml_api.py
CHANGED
@@ -1,19 +1,14 @@
|
|
1 |
-
from fastapi import
|
2 |
-
from fastapi.responses import StreamingResponse
|
3 |
-
|
4 |
from pydantic import BaseModel
|
5 |
-
from fastapi.responses import FileResponse
|
6 |
-
|
7 |
|
|
|
8 |
from modules.api.impl.handler.SSMLHandler import SSMLHandler
|
9 |
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
10 |
from modules.api.impl.model.chattts_model import InferConfig
|
11 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
12 |
|
13 |
|
14 |
-
from modules.api.Api import APIManager
|
15 |
-
|
16 |
-
|
17 |
class SSMLRequest(BaseModel):
|
18 |
ssml: str
|
19 |
format: AudioFormat = "mp3"
|
|
|
1 |
+
from fastapi import Body, HTTPException
|
2 |
+
from fastapi.responses import FileResponse, StreamingResponse
|
|
|
3 |
from pydantic import BaseModel
|
|
|
|
|
4 |
|
5 |
+
from modules.api.Api import APIManager
|
6 |
from modules.api.impl.handler.SSMLHandler import SSMLHandler
|
7 |
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
8 |
from modules.api.impl.model.chattts_model import InferConfig
|
9 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
10 |
|
11 |
|
|
|
|
|
|
|
12 |
class SSMLRequest(BaseModel):
|
13 |
ssml: str
|
14 |
format: AudioFormat = "mp3"
|
modules/api/impl/style_api.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
from modules.data import styles_mgr
|
2 |
from modules.api import utils as api_utils
|
3 |
from modules.api.Api import APIManager
|
|
|
4 |
|
5 |
|
6 |
async def list_styles():
|
|
|
|
|
1 |
from modules.api import utils as api_utils
|
2 |
from modules.api.Api import APIManager
|
3 |
+
from modules.data import styles_mgr
|
4 |
|
5 |
|
6 |
async def list_styles():
|
modules/api/impl/tts_api.py
CHANGED
@@ -1,17 +1,13 @@
|
|
1 |
from fastapi import Depends, HTTPException, Query
|
2 |
-
from fastapi.responses import StreamingResponse
|
3 |
-
|
4 |
from pydantic import BaseModel
|
5 |
-
from fastapi.responses import FileResponse
|
6 |
-
|
7 |
|
|
|
|
|
8 |
from modules.api.impl.handler.TTSHandler import TTSHandler
|
9 |
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
10 |
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
11 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
12 |
-
|
13 |
-
from modules.api import utils as api_utils
|
14 |
-
from modules.api.Api import APIManager
|
15 |
from modules.speaker import Speaker
|
16 |
|
17 |
|
|
|
1 |
from fastapi import Depends, HTTPException, Query
|
2 |
+
from fastapi.responses import FileResponse, StreamingResponse
|
|
|
3 |
from pydantic import BaseModel
|
|
|
|
|
4 |
|
5 |
+
from modules.api import utils as api_utils
|
6 |
+
from modules.api.Api import APIManager
|
7 |
from modules.api.impl.handler.TTSHandler import TTSHandler
|
8 |
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
9 |
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
10 |
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
|
|
|
|
|
|
11 |
from modules.speaker import Speaker
|
12 |
|
13 |
|
modules/api/impl/xtts_v2_api.py
CHANGED
@@ -1,19 +1,17 @@
|
|
1 |
import io
|
|
|
|
|
|
|
2 |
from fastapi import HTTPException
|
3 |
from fastapi.responses import StreamingResponse
|
4 |
from pydantic import BaseModel
|
5 |
-
from modules.api import utils as api_utils
|
6 |
-
from modules.api.Api import APIManager
|
7 |
-
|
8 |
-
import soundfile as sf
|
9 |
|
10 |
from modules import config
|
|
|
|
|
11 |
from modules.normalization import text_normalize
|
12 |
from modules.speaker import speaker_mgr
|
13 |
from modules.synthesize_audio import synthesize_audio
|
14 |
-
|
15 |
-
import logging
|
16 |
-
|
17 |
from modules.utils.audio import apply_prosody_to_audio_data
|
18 |
|
19 |
logger = logging.getLogger(__name__)
|
|
|
1 |
import io
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import soundfile as sf
|
5 |
from fastapi import HTTPException
|
6 |
from fastapi.responses import StreamingResponse
|
7 |
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
8 |
|
9 |
from modules import config
|
10 |
+
from modules.api import utils as api_utils
|
11 |
+
from modules.api.Api import APIManager
|
12 |
from modules.normalization import text_normalize
|
13 |
from modules.speaker import speaker_mgr
|
14 |
from modules.synthesize_audio import synthesize_audio
|
|
|
|
|
|
|
15 |
from modules.utils.audio import apply_prosody_to_audio_data
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
modules/api/utils.py
CHANGED
@@ -1,14 +1,10 @@
|
|
1 |
-
from pydantic import BaseModel
|
2 |
from typing import Any, Union
|
3 |
|
4 |
-
|
5 |
-
from modules.speaker import speaker_mgr
|
6 |
-
|
7 |
-
|
8 |
-
from modules.data import styles_mgr
|
9 |
-
|
10 |
from pydub import AudioSegment
|
11 |
|
|
|
|
|
12 |
from modules.ssml import merge_prompt
|
13 |
|
14 |
|
|
|
|
|
1 |
from typing import Any, Union
|
2 |
|
3 |
+
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
4 |
from pydub import AudioSegment
|
5 |
|
6 |
+
from modules.data import styles_mgr
|
7 |
+
from modules.speaker import speaker_mgr
|
8 |
from modules.ssml import merge_prompt
|
9 |
|
10 |
|
modules/api/worker.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import argparse
|
2 |
import logging
|
3 |
import os
|
|
|
4 |
import dotenv
|
5 |
from fastapi import FastAPI
|
6 |
|
@@ -12,6 +13,7 @@ logging.basicConfig(
|
|
12 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
13 |
)
|
14 |
|
|
|
15 |
from modules.api.api_setup import (
|
16 |
process_api_args,
|
17 |
process_model_args,
|
@@ -20,7 +22,6 @@ from modules.api.api_setup import (
|
|
20 |
setup_uvicon_args,
|
21 |
)
|
22 |
from modules.api.app_config import app_description, app_title, app_version
|
23 |
-
from modules import config
|
24 |
from modules.utils.torch_opt import configure_torch_optimizations
|
25 |
|
26 |
dotenv.load_dotenv(
|
|
|
1 |
import argparse
|
2 |
import logging
|
3 |
import os
|
4 |
+
|
5 |
import dotenv
|
6 |
from fastapi import FastAPI
|
7 |
|
|
|
13 |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
14 |
)
|
15 |
|
16 |
+
from modules import config
|
17 |
from modules.api.api_setup import (
|
18 |
process_api_args,
|
19 |
process_model_args,
|
|
|
22 |
setup_uvicon_args,
|
23 |
)
|
24 |
from modules.api.app_config import app_description, app_title, app_version
|
|
|
25 |
from modules.utils.torch_opt import configure_torch_optimizations
|
26 |
|
27 |
dotenv.load_dotenv(
|
modules/config.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import sys
|
2 |
|
3 |
import torch
|
4 |
-
from modules.utils.JsonObject import JsonObject
|
5 |
|
6 |
-
from modules.utils import
|
|
|
7 |
|
8 |
# TODO impl RuntimeEnvVars() class
|
9 |
runtime_env_vars = JsonObject({})
|
|
|
1 |
import sys
|
2 |
|
3 |
import torch
|
|
|
4 |
|
5 |
+
from modules.utils import ffmpeg, git
|
6 |
+
from modules.utils.JsonObject import JsonObject
|
7 |
|
8 |
# TODO impl RuntimeEnvVars() class
|
9 |
runtime_env_vars = JsonObject({})
|
modules/data.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
from modules.utils.CsvMgr import BaseManager
|
2 |
|
3 |
-
|
4 |
# speakers_mgr = BaseManager("./data/speakers.csv")
|
5 |
styles_mgr = BaseManager("./data/styles.csv")
|
6 |
|
|
|
1 |
from modules.utils.CsvMgr import BaseManager
|
2 |
|
|
|
3 |
# speakers_mgr = BaseManager("./data/speakers.csv")
|
4 |
styles_mgr = BaseManager("./data/styles.csv")
|
5 |
|
modules/denoise.py
CHANGED
@@ -1,15 +1,13 @@
|
|
1 |
import os
|
2 |
from typing import Union
|
3 |
|
|
|
4 |
import torch
|
5 |
import torchaudio
|
6 |
-
from modules.Denoiser.AudioDenoiser import AudioDenoiser
|
7 |
-
|
8 |
-
from modules.utils.constants import MODELS_DIR
|
9 |
|
|
|
10 |
from modules.devices import devices
|
11 |
-
|
12 |
-
import soundfile as sf
|
13 |
|
14 |
ad: Union[AudioDenoiser, None] = None
|
15 |
|
|
|
1 |
import os
|
2 |
from typing import Union
|
3 |
|
4 |
+
import soundfile as sf
|
5 |
import torch
|
6 |
import torchaudio
|
|
|
|
|
|
|
7 |
|
8 |
+
from modules.Denoiser.AudioDenoiser import AudioDenoiser
|
9 |
from modules.devices import devices
|
10 |
+
from modules.utils.constants import MODELS_DIR
|
|
|
11 |
|
12 |
ad: Union[AudioDenoiser, None] = None
|
13 |
|
modules/devices/devices.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
-
|
2 |
import sys
|
|
|
|
|
3 |
import torch
|
4 |
-
from modules import config
|
5 |
|
6 |
-
import
|
7 |
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
|
|
1 |
+
import logging
|
2 |
import sys
|
3 |
+
from functools import lru_cache
|
4 |
+
|
5 |
import torch
|
|
|
6 |
|
7 |
+
from modules import config
|
8 |
|
9 |
logger = logging.getLogger(__name__)
|
10 |
|
modules/devices/mac_devices.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
-
import torch
|
2 |
import logging
|
3 |
-
|
|
|
4 |
import torch.backends
|
5 |
import torch.backends.mps
|
|
|
6 |
|
7 |
logger = logging.getLogger(__name__)
|
8 |
|
|
|
|
|
1 |
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
import torch.backends
|
5 |
import torch.backends.mps
|
6 |
+
from packaging import version
|
7 |
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
modules/ffmpeg_env.py
CHANGED
@@ -1,6 +1,7 @@
|
|
|
|
1 |
import os
|
|
|
2 |
from modules.utils.constants import ROOT_DIR
|
3 |
-
import logging
|
4 |
|
5 |
logger = logging.getLogger(__name__)
|
6 |
|
|
|
1 |
+
import logging
|
2 |
import os
|
3 |
+
|
4 |
from modules.utils.constants import ROOT_DIR
|
|
|
5 |
|
6 |
logger = logging.getLogger(__name__)
|
7 |
|
modules/finetune/train_speaker.py
CHANGED
@@ -3,9 +3,10 @@ import torch.nn.functional as F
|
|
3 |
import transformers
|
4 |
|
5 |
from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
|
6 |
-
from modules.finetune.utils.output import get_ansi_len, output_iter
|
7 |
-
|
8 |
from .utils.dataset import AudioCollator, XzListTar
|
|
|
9 |
from .utils.model import quantize
|
10 |
|
11 |
IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
|
@@ -201,11 +202,13 @@ def train_speaker_embeddings(
|
|
201 |
if __name__ == "__main__":
|
202 |
import argparse
|
203 |
import os
|
204 |
-
import numpy as np
|
205 |
import pathlib
|
206 |
-
|
207 |
-
|
|
|
208 |
from modules import config
|
|
|
|
|
209 |
from modules.speaker import Speaker
|
210 |
|
211 |
config.runtime_env_vars.no_half = True
|
|
|
3 |
import transformers
|
4 |
|
5 |
from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
|
6 |
+
from modules.finetune.utils.output import ansi, get_ansi_len, output_iter
|
7 |
+
|
8 |
from .utils.dataset import AudioCollator, XzListTar
|
9 |
+
from .utils.logger import MetricLogger
|
10 |
from .utils.model import quantize
|
11 |
|
12 |
IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
|
|
|
202 |
if __name__ == "__main__":
|
203 |
import argparse
|
204 |
import os
|
|
|
205 |
import pathlib
|
206 |
+
|
207 |
+
import numpy as np
|
208 |
+
|
209 |
from modules import config
|
210 |
+
from modules.devices import devices
|
211 |
+
from modules.models import load_chat_tts
|
212 |
from modules.speaker import Speaker
|
213 |
|
214 |
config.runtime_env_vars.no_half = True
|
modules/finetune/utils/dataset.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
-
import
|
2 |
import functools
|
3 |
-
import json
|
4 |
-
import tarfile
|
5 |
import io
|
|
|
6 |
import logging
|
7 |
-
import
|
|
|
8 |
import typing
|
9 |
|
10 |
import torch.utils.data
|
11 |
import torchaudio
|
12 |
-
from torchvision.datasets.utils import download_url
|
13 |
import transformers
|
14 |
import vocos
|
|
|
15 |
|
16 |
from modules.ChatTTS.ChatTTS.utils.infer_utils import (
|
17 |
-
count_invalid_characters,
|
18 |
apply_character_map,
|
|
|
19 |
)
|
20 |
|
21 |
|
|
|
1 |
+
import abc
|
2 |
import functools
|
|
|
|
|
3 |
import io
|
4 |
+
import json
|
5 |
import logging
|
6 |
+
import os
|
7 |
+
import tarfile
|
8 |
import typing
|
9 |
|
10 |
import torch.utils.data
|
11 |
import torchaudio
|
|
|
12 |
import transformers
|
13 |
import vocos
|
14 |
+
from torchvision.datasets.utils import download_url
|
15 |
|
16 |
from modules.ChatTTS.ChatTTS.utils.infer_utils import (
|
|
|
17 |
apply_character_map,
|
18 |
+
count_invalid_characters,
|
19 |
)
|
20 |
|
21 |
|
modules/finetune/utils/logger.py
CHANGED
@@ -3,15 +3,14 @@
|
|
3 |
import statistics
|
4 |
import time
|
5 |
from collections import defaultdict, deque
|
6 |
-
from tqdm import tqdm as tqdm_class
|
7 |
-
|
8 |
from typing import Generator, Iterable, TypeVar
|
9 |
-
from typing_extensions import Self
|
10 |
|
11 |
import torch
|
12 |
import torch.distributed as dist
|
|
|
|
|
13 |
|
14 |
-
from .output import ansi,
|
15 |
|
16 |
__all__ = ["SmoothedValue", "MetricLogger"]
|
17 |
|
|
|
3 |
import statistics
|
4 |
import time
|
5 |
from collections import defaultdict, deque
|
|
|
|
|
6 |
from typing import Generator, Iterable, TypeVar
|
|
|
7 |
|
8 |
import torch
|
9 |
import torch.distributed as dist
|
10 |
+
from tqdm import tqdm as tqdm_class
|
11 |
+
from typing_extensions import Self
|
12 |
|
13 |
+
from .output import ansi, get_ansi_len, prints
|
14 |
|
15 |
__all__ = ["SmoothedValue", "MetricLogger"]
|
16 |
|
modules/generate_audio.py
CHANGED
@@ -1,18 +1,15 @@
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
|
4 |
-
from modules
|
5 |
-
from modules.utils.SeedContext import SeedContext
|
6 |
-
|
7 |
-
from modules import models, config
|
8 |
-
|
9 |
-
import logging
|
10 |
-
import gc
|
11 |
-
|
12 |
from modules.devices import devices
|
13 |
-
from
|
14 |
-
|
15 |
from modules.utils.cache import conditional_cache
|
|
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
|
|
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
|
8 |
+
from modules import config, models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from modules.devices import devices
|
10 |
+
from modules.speaker import Speaker
|
|
|
11 |
from modules.utils.cache import conditional_cache
|
12 |
+
from modules.utils.SeedContext import SeedContext
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
modules/models.py
CHANGED
@@ -1,13 +1,13 @@
|
|
|
|
|
|
1 |
import threading
|
|
|
2 |
import torch
|
3 |
-
|
4 |
from modules import config
|
|
|
5 |
from modules.devices import devices
|
6 |
|
7 |
-
import logging
|
8 |
-
import gc
|
9 |
-
|
10 |
-
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
chat_tts = None
|
|
|
1 |
+
import gc
|
2 |
+
import logging
|
3 |
import threading
|
4 |
+
|
5 |
import torch
|
6 |
+
|
7 |
from modules import config
|
8 |
+
from modules.ChatTTS import ChatTTS
|
9 |
from modules.devices import devices
|
10 |
|
|
|
|
|
|
|
|
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
chat_tts = None
|
modules/normalization.py
CHANGED
@@ -1,9 +1,11 @@
|
|
|
|
1 |
from functools import lru_cache
|
2 |
-
|
3 |
import emojiswitch
|
4 |
-
|
5 |
from modules import models
|
6 |
-
import
|
|
|
7 |
|
8 |
# 是否关闭 unk token 检查
|
9 |
# NOTE: 单测的时候用于跳过模型加载
|
|
|
1 |
+
import re
|
2 |
from functools import lru_cache
|
3 |
+
|
4 |
import emojiswitch
|
5 |
+
|
6 |
from modules import models
|
7 |
+
from modules.utils.markdown import markdown_to_text
|
8 |
+
from modules.utils.zh_normalization.text_normlization import *
|
9 |
|
10 |
# 是否关闭 unk token 检查
|
11 |
# NOTE: 单测的时候用于跳过模型加载
|
modules/prompts/news_oral_prompt.txt
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
|
2 |
-
|
3 |
|
4 |
-
|
5 |
同时,适当的添加一些 附语言 标签为文本增加多样性
|
6 |
|
7 |
目前可以使用的附语言标签如下:
|
@@ -10,5 +10,24 @@
|
|
10 |
- `[v_break]`: 表示有声停顿,如“嗯”、“啊”等
|
11 |
- `[lbreak]`: 表示一个长停顿一般表示段落结束
|
12 |
|
13 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
{{USER_INPUT}}
|
|
|
1 |
+
#任务要求
|
2 |
+
任务:新闻稿口播化
|
3 |
|
4 |
+
你需要将一个新闻稿改写为口语化的口播文本,以提供给新闻主播在晚间新闻节目中播报
|
5 |
同时,适当的添加一些 附语言 标签为文本增加多样性
|
6 |
|
7 |
目前可以使用的附语言标签如下:
|
|
|
10 |
- `[v_break]`: 表示有声停顿,如“嗯”、“啊”等
|
11 |
- `[lbreak]`: 表示一个长停顿一般表示段落结束
|
12 |
|
13 |
+
# examples
|
14 |
+
## case 1
|
15 |
+
- input: `天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖`
|
16 |
+
- output: `天气预报显示,今天会有小雨,请大家出门时记得带伞[uv_break]。那降温的天气[uv_break]也提醒我们要适时添衣保暖[lbreak]`
|
17 |
+
|
18 |
+
## case 2
|
19 |
+
- input: `请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯`
|
20 |
+
- output: `请注意啊,这个电梯将在下午两点进行[uv_break]例行维护[uv_break],预计需要一个小时的时间[uv_break],请大家在此期间使用楼梯[lbreak]`
|
21 |
+
|
22 |
+
## case 3
|
23 |
+
- input: `它的任务是简化记者编辑的工作流程。记者写稿时可以用标签来标明关键词、标题或主题。随着时间推移,数据积累到一定程度后,机器编辑就能自动识别这些标签`
|
24 |
+
- output: `它的任务呢是简化记者编辑的工作流程[uv_break]。记者写稿时呢可以用标签来标明关键词啊、标题啊或主题[uv_break]。那随着时间推移呢,数据积累到一定程度后[uv_break],机器编辑就能自动识别这些标签[uv_break]`
|
25 |
+
|
26 |
+
## case 4
|
27 |
+
- input: `有一天,小明问他爸爸:“爸爸,我是不是傻孩子啊?”
|
28 |
+
|
29 |
+
爸爸说:“傻孩子,你怎么会是傻孩子呢?”`
|
30 |
+
- output: `然后有一天呢,小明问他[uv_break]爸爸[uv_break],爸爸,我是不是傻孩[uv_break]子啊?爸爸说,傻孩[laugh]子啊,你怎么会是傻孩子呢[laugh]?`
|
31 |
+
|
32 |
+
# 用户输入
|
33 |
{{USER_INPUT}}
|
modules/refiner.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
|
|
|
4 |
from modules.utils.SeedContext import SeedContext
|
5 |
|
6 |
-
from modules import models, config
|
7 |
-
|
8 |
|
9 |
@torch.inference_mode()
|
10 |
def refine_text(
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
|
4 |
+
from modules import config, models
|
5 |
from modules.utils.SeedContext import SeedContext
|
6 |
|
|
|
|
|
7 |
|
8 |
@torch.inference_mode()
|
9 |
def refine_text(
|
modules/repos_static/resemble_enhance/common.py
CHANGED
@@ -42,7 +42,9 @@ class Normalizer(nn.Module):
|
|
42 |
self.running_var_unsafe = x.var()
|
43 |
else:
|
44 |
self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
|
45 |
-
self.running_var_unsafe = self._ema(
|
|
|
|
|
46 |
|
47 |
def forward(self, x: Tensor, update=True):
|
48 |
if self.training and update:
|
|
|
42 |
self.running_var_unsafe = x.var()
|
43 |
else:
|
44 |
self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
|
45 |
+
self.running_var_unsafe = self._ema(
|
46 |
+
self.running_var_unsafe, (x - self.running_mean).pow(2).mean()
|
47 |
+
)
|
48 |
|
49 |
def forward(self, x: Tensor, update=True):
|
50 |
if self.training and update:
|
modules/repos_static/resemble_enhance/data/dataset.py
CHANGED
@@ -44,7 +44,9 @@ def praat_augment(wav, sr):
|
|
44 |
sound = parselmouth.Sound(wav, sr)
|
45 |
formant_shift_ratio = random.uniform(1.1, 1.5)
|
46 |
pitch_range_factor = random.uniform(0.5, 2.0)
|
47 |
-
sound = parselmouth.praat.call(
|
|
|
|
|
48 |
wav = np.array(sound.values)[0].astype(np.float32)
|
49 |
return wav
|
50 |
|
@@ -73,7 +75,9 @@ class Dataset(DatasetBase):
|
|
73 |
if len(self.bg_paths) == 0:
|
74 |
raise ValueError(f"No background audio files found in {hp.bg_dir}")
|
75 |
|
76 |
-
logger.info(
|
|
|
|
|
77 |
|
78 |
self.training = training
|
79 |
self.max_retries = max_retries
|
@@ -121,7 +125,9 @@ class Dataset(DatasetBase):
|
|
121 |
fg_path = self.fg_paths[index]
|
122 |
|
123 |
if self.training and random.random() < self.silent_fg_prob:
|
124 |
-
fg_wav = np.zeros(
|
|
|
|
|
125 |
else:
|
126 |
fg_wav = self._load_wav(fg_path)
|
127 |
if random.random() < self.hp.praat_augment_prob and self.training:
|
@@ -132,14 +138,20 @@ class Dataset(DatasetBase):
|
|
132 |
fg_dwav = None
|
133 |
bg_dwav = None
|
134 |
else:
|
135 |
-
fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(
|
|
|
|
|
136 |
if self.training:
|
137 |
bg_path = random.choice(self.bg_paths)
|
138 |
else:
|
139 |
# Deterministic for validation
|
140 |
bg_path = self.bg_paths[index % len(self.bg_paths)]
|
141 |
-
bg_wav = self._load_wav(
|
142 |
-
|
|
|
|
|
|
|
|
|
143 |
|
144 |
return dict(
|
145 |
fg_wav=fg_wav,
|
@@ -154,7 +166,9 @@ class Dataset(DatasetBase):
|
|
154 |
return self._getitem_unsafe(index)
|
155 |
except Exception as e:
|
156 |
if i == self.max_retries - 1:
|
157 |
-
raise RuntimeError(
|
|
|
|
|
158 |
logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
|
159 |
index = np.random.randint(0, len(self))
|
160 |
|
|
|
44 |
sound = parselmouth.Sound(wav, sr)
|
45 |
formant_shift_ratio = random.uniform(1.1, 1.5)
|
46 |
pitch_range_factor = random.uniform(0.5, 2.0)
|
47 |
+
sound = parselmouth.praat.call(
|
48 |
+
sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0
|
49 |
+
)
|
50 |
wav = np.array(sound.values)[0].astype(np.float32)
|
51 |
return wav
|
52 |
|
|
|
75 |
if len(self.bg_paths) == 0:
|
76 |
raise ValueError(f"No background audio files found in {hp.bg_dir}")
|
77 |
|
78 |
+
logger.info(
|
79 |
+
f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files"
|
80 |
+
)
|
81 |
|
82 |
self.training = training
|
83 |
self.max_retries = max_retries
|
|
|
125 |
fg_path = self.fg_paths[index]
|
126 |
|
127 |
if self.training and random.random() < self.silent_fg_prob:
|
128 |
+
fg_wav = np.zeros(
|
129 |
+
int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32
|
130 |
+
)
|
131 |
else:
|
132 |
fg_wav = self._load_wav(fg_path)
|
133 |
if random.random() < self.hp.praat_augment_prob and self.training:
|
|
|
138 |
fg_dwav = None
|
139 |
bg_dwav = None
|
140 |
else:
|
141 |
+
fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(
|
142 |
+
np.float32
|
143 |
+
)
|
144 |
if self.training:
|
145 |
bg_path = random.choice(self.bg_paths)
|
146 |
else:
|
147 |
# Deterministic for validation
|
148 |
bg_path = self.bg_paths[index % len(self.bg_paths)]
|
149 |
+
bg_wav = self._load_wav(
|
150 |
+
bg_path, length=len(fg_wav), random_crop=self.training
|
151 |
+
)
|
152 |
+
bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(
|
153 |
+
np.float32
|
154 |
+
)
|
155 |
|
156 |
return dict(
|
157 |
fg_wav=fg_wav,
|
|
|
166 |
return self._getitem_unsafe(index)
|
167 |
except Exception as e:
|
168 |
if i == self.max_retries - 1:
|
169 |
+
raise RuntimeError(
|
170 |
+
f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries"
|
171 |
+
) from e
|
172 |
logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
|
173 |
index = np.random.randint(0, len(self))
|
174 |
|
modules/repos_static/resemble_enhance/data/distorter/base.py
CHANGED
@@ -2,8 +2,8 @@ import itertools
|
|
2 |
import os
|
3 |
import random
|
4 |
import time
|
5 |
-
from typing import Union
|
6 |
import warnings
|
|
|
7 |
|
8 |
import numpy as np
|
9 |
|
|
|
2 |
import os
|
3 |
import random
|
4 |
import time
|
|
|
5 |
import warnings
|
6 |
+
from typing import Union
|
7 |
|
8 |
import numpy as np
|
9 |
|