Spaces:
Runtime error
Runtime error
File size: 2,292 Bytes
96ee597 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
"""Vocoder wrapper.
Copyright PolyAI Limited.
"""
import enum
import warnings
from typing import Optional

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from speechtokenizer import SpeechTokenizer
class VocoderType(enum.Enum):
    """Supported vocoder back-ends.

    Each member's value is a ``(label, compression_ratio)`` tuple, which the
    Enum machinery unpacks into ``__init__``.
    """

    # compression_ratio: presumably the number of audio samples per code
    # frame (hop length) — TODO confirm against the model config.
    SPEECHTOKENIZER = ("SPEECHTOKENIZER", 320)

    def __init__(self, label, compression_ratio):
        # Enum already sets ``_name_`` to the member name before calling
        # __init__, so it is not reassigned here; the label stays available
        # as ``self.value[0]``.
        self.compression_ratio = compression_ratio

    def get_vocoder(self, ckpt_path, config_path, **kwargs):
        """Instantiate the vocoder wrapper for this type.

        Args:
            ckpt_path: Checkpoint path; when falsy, STWrapper's default is used.
            config_path: Config path; when falsy, STWrapper's default is used.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            An ``STWrapper`` instance.

        Raises:
            ValueError: If this member has no vocoder implementation.
        """
        if self is not VocoderType.SPEECHTOKENIZER:
            raise ValueError(f"Unknown vocoder type {self.name}")
        # Forward only the paths the caller supplied. A missing one then falls
        # back to STWrapper's default instead of being passed as None
        # (ckpt_path without config_path) or silently ignored (config_path
        # without ckpt_path).
        overrides = {}
        if ckpt_path:
            overrides["ckpt_path"] = ckpt_path
        if config_path:
            overrides["config_path"] = config_path
        return STWrapper(**overrides)
class STWrapper(nn.Module):
    """``nn.Module`` wrapper around a SpeechTokenizer model.

    Provides tensor-level ``encode``/``decode`` that run on the model's device
    and return results on the caller's device, plus ``*_to_file`` helpers
    that read/write ``.npy`` codes and audio files.
    """

    def __init__(
        self,
        ckpt_path: str = './ckpt/speechtokenizer/SpeechTokenizer.pt',
        config_path: str = './ckpt/speechtokenizer/config.json',
    ):
        """Load the SpeechTokenizer model.

        Args:
            ckpt_path: Path to the model checkpoint.
            config_path: Path to the model's JSON config.
        """
        super().__init__()
        # Note the order: load_from_checkpoint is called with the config path
        # first and the checkpoint path second.
        self.model = SpeechTokenizer.load_from_checkpoint(
            config_path, ckpt_path)

    def eval(self):
        """Put the wrapper and the wrapped model into inference mode.

        Returns:
            ``self``, matching ``nn.Module.eval`` so ``m = m.eval()`` works.
        """
        # nn.Module.eval clears ``training`` on this module and its children.
        super().eval()
        # Also call the inner model's own eval() in case it customises it.
        self.model.eval()
        return self

    @torch.no_grad()
    def decode(self, codes: torch.Tensor, verbose: bool = False):
        """Decode discrete codes back to audio.

        Args:
            codes: Code tensor; presumably laid out (n_q, B, T) like the
                output of ``encode`` — TODO confirm.
            verbose: Unused; kept for interface compatibility.

        Returns:
            The decoded audio tensor, on the same device as ``codes``.
        """
        caller_device = codes.device
        audio = self.model.decode(codes.to(self.device))
        return audio.to(caller_device)

    def decode_to_file(self, codes_path, out_path) -> None:
        """Decode codes saved with ``np.save`` and write them as an audio file.

        Args:
            codes_path: Path to a ``.npy`` file of codes.
            out_path: Destination audio path; written at the model's sample rate.
        """
        codes = torch.from_numpy(np.load(codes_path))
        wav = self.decode(codes).cpu().numpy()
        # soundfile expects (frames,) or (frames, channels). The decoded
        # output presumably carries singleton batch/channel axes (e.g.
        # (1, 1, T), mirroring the input built in encode_to_file), so drop
        # leading size-1 axes to write a single mono clip as 1-D.
        while wav.ndim > 1 and wav.shape[0] == 1:
            wav = wav[0]
        sf.write(out_path, wav, samplerate=self.model.sample_rate)

    @torch.no_grad()
    def encode(self, wav, verbose=False, n_quantizers: Optional[int] = None):
        """Encode a waveform tensor into discrete codes.

        Args:
            wav: Audio tensor; ``encode_to_file`` passes shape (1, 1, T).
            verbose: Unused; kept for interface compatibility.
            n_quantizers: If given, keep only the first ``n_quantizers``
                codebook levels (leading axis of the codes).

        Returns:
            Codes of shape (n_q, B, T), on the same device as ``wav``.
        """
        caller_device = wav.device
        codes = self.model.encode(wav.to(self.device))  # codes: (n_q, B, T)
        if n_quantizers is not None:
            # Codebook levels are ordered along axis 0, so truncating it
            # keeps the coarsest ``n_quantizers`` levels.
            codes = codes[:n_quantizers]
        return codes.to(caller_device)

    def encode_to_file(self, wav_path, out_path) -> None:
        """Encode an audio file and save the codes with ``np.save``.

        The audio is not resampled; a warning is issued if the file's sample
        rate differs from the model's. Multichannel audio is averaged down to
        mono.

        Args:
            wav_path: Path to the input audio file.
            out_path: Destination ``.npy`` path for the codes.
        """
        wav, file_rate = sf.read(wav_path, dtype='float32')
        if file_rate != self.model.sample_rate:
            warnings.warn(
                f"{wav_path} has sample rate {file_rate} Hz but the model "
                f"uses {self.model.sample_rate} Hz; audio is not resampled, "
                "so the resulting codes may be wrong.",
                stacklevel=2,
            )
        if wav.ndim > 1:
            # sf.read returns (frames, channels) for multichannel files; the
            # model input built below has a single channel, so downmix.
            wav = wav.mean(axis=1)
        batch = torch.from_numpy(wav)[None, None]  # (1, 1, T)
        codes = self.encode(batch).cpu().numpy()
        np.save(out_path, codes)

    def remove_weight_norm(self):
        """Intentionally a no-op.

        Presumably kept so callers can treat all vocoder wrappers uniformly.
        """

    @property
    def device(self):
        """Device of the wrapped model's first parameter.

        Raises ``StopIteration`` if the model has no parameters.
        """
        return next(self.model.parameters()).device
|