"""Vocoder wrapper.

Copyright PolyAI Limited.
"""
import enum

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from speechtokenizer import SpeechTokenizer


class VocoderType(enum.Enum):
    """Supported vocoder types, each paired with its compression ratio
    (audio samples per codec frame)."""

    SPEECHTOKENIZER = ("SPEECHTOKENIZER", 320)

    def __init__(self, name, compression_ratio):
        self._name_ = name
        self.compression_ratio = compression_ratio

    def get_vocoder(self, ckpt_path, config_path, **kwargs):
        """Instantiate the vocoder, falling back to the wrapper's default
        checkpoint and config paths when ckpt_path is not given."""
        if self.name == "SPEECHTOKENIZER":
            if ckpt_path:
                vocoder = STWrapper(ckpt_path, config_path)
            else:
                vocoder = STWrapper()
        else:
            raise ValueError(f"Unknown vocoder type {self.name}")
        return vocoder


class STWrapper(nn.Module):
    """Thin wrapper around SpeechTokenizer that handles device placement
    and adds file-based encode/decode helpers."""

    def __init__(
            self,
            ckpt_path: str = './ckpt/speechtokenizer/SpeechTokenizer.pt',
            config_path: str = './ckpt/speechtokenizer/config.json',
        ):
        super().__init__()
        self.model = SpeechTokenizer.load_from_checkpoint(
            config_path, ckpt_path)

    def eval(self):
        self.model.eval()
        return self  # match nn.Module.eval(), which returns self for chaining

    @torch.no_grad()
    def decode(self, codes: torch.Tensor, verbose: bool = False):
        """Decode RVQ codes of shape (n_q, B, T) into a waveform tensor."""
        original_device = codes.device
        codes = codes.to(self.device)
        audio_array = self.model.decode(codes)
        return audio_array.to(original_device)

    def decode_to_file(self, codes_path, out_path) -> None:
        codes = np.load(codes_path)
        codes = torch.from_numpy(codes)
        # decode() returns a batched (B, 1, T) tensor; squeeze to 1-D mono
        # audio, since soundfile cannot write a 3-D array.
        wav = self.decode(codes).cpu().numpy().squeeze()
        sf.write(out_path, wav, samplerate=self.model.sample_rate)

    @torch.no_grad()
    def encode(self, wav, verbose=False, n_quantizers: int = None):
        """Encode a (B, 1, T) waveform into RVQ codes of shape (n_q, B, T)."""
        original_device = wav.device
        wav = wav.to(self.device)
        # Forward n_quantizers so the argument is not silently ignored.
        codes = self.model.encode(wav, n_q=n_quantizers)  # codes: (n_q, B, T)
        return codes.to(original_device)

    def encode_to_file(self, wav_path, out_path) -> None:
        wav, _ = sf.read(wav_path, dtype='float32')
        # Assumes mono input: (T,) -> (1, 1, T), the batch layout encode() expects.
        wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0)
        codes = self.encode(wav).cpu().numpy()
        np.save(out_path, codes)

    def remove_weight_norm(self):
        """No-op, kept for interface parity with other vocoder wrappers."""
        pass

    @property
    def device(self):
        return next(self.model.parameters()).device
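

# A minimal usage sketch (the file names below are hypothetical, and a
# SpeechTokenizer checkpoint/config is assumed to exist at the default
# paths): round-trip a wav file through encode_to_file / decode_to_file.
if __name__ == "__main__":
    vocoder = VocoderType.SPEECHTOKENIZER.get_vocoder(
        ckpt_path=None, config_path=None)  # falls back to default paths
    vocoder.eval()
    vocoder.encode_to_file("example.wav", "example_codes.npy")
    vocoder.decode_to_file("example_codes.npy", "example_reconstructed.wav")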