Tags: Text Generation · Transformers · PyTorch · English · llama · text-generation-inference · Inference Endpoints

Inference: the snippet below drives both directions through a single causal LM. infer_text_to_audio and decode_tts handle text-to-speech using SpeechTokenizer codes (3 codebooks), while infer_audio_to_text handles speech-to-text from SpeechTokenizer semantic codes plus WavTokenizer acoustic codes.

import torch
import torchaudio

from transformers import AutoTokenizer, AutoModelForCausalLM
from audiotools import AudioSignal             # descript-audiotools
from speechtokenizer import SpeechTokenizer    # SpeechTokenizer package
from decoder.pretrained import WavTokenizer    # import path from the WavTokenizer repo; adjust to your local layout

device = "cuda"

n_codebooks_tts = 3
n_codebooks_asr = 1

start_audio_token = "<|start_of_audio|>"
end_audio_token = "<|end_of_audio|>"
end_sequence_token = "<|end_of_text|>"

base_model = "Vikhrmodels/salt-asr_speech_1_wav_1_tts_speech_3_text-10k"


def decode_tts(tokens, quantizer, n_codebooks, n_original_tokens, start_audio_token_id, end_audio_token_id):
    # find start and end indices of audio tokens
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)

    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # subtract length of original vocabulary -> tokens in range [0, 1024)
    audio_tokens = tokens[start:end] % n_original_tokens
    remainder = audio_tokens.shape[-1] % n_codebooks

    if remainder:
        # pad if the last frame is incomplete (match dtype so the codes stay integer)
        pad_tokens = torch.zeros(n_codebooks - remainder, dtype=audio_tokens.dtype, device=device)
        audio_tokens = torch.cat([audio_tokens, pad_tokens], dim=0)

    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)

    del tokens
    del audio_tokens
    torch.cuda.empty_cache()

    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)
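
# Illustration (not part of the original card): decode_tts assumes the model emits audio
# tokens frame by frame, interleaved across codebooks. With n_codebooks = 3, a flat
# sequence [c0_t0, c1_t0, c2_t0, c0_t1, c1_t1, c2_t1, ...] is regrouped into one row per
# codebook before decoding:
#
#     flat = torch.arange(6)                # two frames, three codebooks
#     flat.view(-1, 3).t().view(3, 1, -1)   # -> [[[0, 3]], [[1, 4]], [[2, 5]]]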


def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
        temperature=0.1,
        repetition_penalty=1.1,
        length_penalty=1.2,
        no_repeat_ngram_size=3,
    )

    audio_signal = decode_tts(output_audio_tokens[0], quantizer, n_codebooks_tts, len(tokenizer), soa, eoa)

    return audio_signal


def infer_audio_to_text(audio_path, model, tokenizer, quantizer_speech, quantizer_wav, max_seq_length=1024, top_k=20):
    audio_data, sample_rate = torchaudio.load(audio_path)

    audio = audio_data.view(1, -1).float().to(device)
    bandwidth_id = torch.tensor([0], device=device)

    # semantic tokens from SpeechTokenizer (first codebook only), offset past the text vocabulary
    codes_semantics = quantizer_speech.encode(audio.reshape(1, 1, -1))
    raw_semantic_tokens = codes_semantics + len(tokenizer)
    raw_semantic_tokens = raw_semantic_tokens[:1].view(1, -1)

    # acoustic tokens from WavTokenizer, offset past the text vocabulary and the semantic codebook
    _, codes = quantizer_wav.encode_infer(audio, bandwidth_id=bandwidth_id)
    raw_acoustic_tokens = codes + len(tokenizer) + 1024
    raw_acoustic_tokens = raw_acoustic_tokens.view(1, -1)

    audio_tokens = torch.cat([raw_semantic_tokens, raw_acoustic_tokens], dim=1)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    tokens = torch.cat([soa, audio_tokens, eoa], dim=1)

    attention_mask = torch.ones(tokens.size(), device=device)

    output_text_tokens = model.generate(
        tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        top_k=top_k,
    )

    output_text_tokens = output_text_tokens.cpu()[0]
    # keep only ids below the first audio special token, i.e. plain text tokens
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)

    return decoded_text
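
# Note (assumption, not stated in the card): torchaudio.load returns the file's native sample
# rate and infer_audio_to_text uses the waveform as-is. If the input rate differs from what the
# quantizers expect, a minimal resampling sketch (hypothetical helper) would be:
#
#     def maybe_resample(audio_data, sample_rate, target_rate):
#         if sample_rate != target_rate:
#             audio_data = torchaudio.functional.resample(audio_data, sample_rate, target_rate)
#         return audio_data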


tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    cache_dir=".",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map={"": 0}
)

quantizer_speech = SpeechTokenizer.load_from_checkpoint("speechtokenizer/config.json",
                                                        "speechtokenizer/SpeechTokenizer.pt")
quantizer_speech = quantizer_speech.eval().to(device)
codebook_size = quantizer_speech.quantizer.bins

quantizer_wav = WavTokenizer.from_pretrained0802("wavtokenizer/config.yaml",
                                                 "wavtokenizer/WavTokenizer_small_600_24k_4096.ckpt")
quantizer_wav = quantizer_wav.to(device)
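
# Sketch (inferred from the offsets above, not stated explicitly): the extended vocabulary is
# laid out as
#     [0, len(tokenizer))                             -> text tokens
#     [len(tokenizer), len(tokenizer) + 1024)         -> SpeechTokenizer semantic codes
#     [len(tokenizer) + 1024, len(tokenizer) + 2048)  -> WavTokenizer acoustic codes
# which is why decode_tts maps generated ids back with "% len(tokenizer)" and
# infer_audio_to_text adds len(tokenizer) / len(tokenizer) + 1024 before feeding audio tokens
# to the model.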

text = ("Say 'COUNT NUMBERS FROM ONE TO TEN' with a male speaker delivers a very monotone and "
        "low-pitched speech with a moderate speed in a setting with almost no noise, "
        "creating a clear and quiet recording.")

audio_signal = infer_text_to_audio(text, model, tokenizer, quantizer_speech, top_k=60)
audio_signal.write("output.wav")

audio_path = "./input.wav"
generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer_speech, quantizer_wav, top_k=10)
print(generated_text)