# S3TVR-Demo / inference_functions.py
import time
import torch
import torchaudio
import noisereduce as nr
import numpy as np
from models.nllb import nllb_translate


def translate(model_nllb, tokenizer_nllb, text, target_lang):
    """Translate `text` into `target_lang` with a preloaded NLLB model and tokenizer."""
    print("Processing translation...")
    start_time = time.time()
    translation = nllb_translate(model_nllb, tokenizer_nllb, text, target_lang)
    print("Translation:", translation)
    print("Translation time:", time.time() - start_time)
    return translation
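
# Example usage of `translate` (a sketch: the NLLB checkpoint name is the
# standard Hub release, and the FLORES-200 target code "spa_Latn" is an
# assumption about what `nllb_translate` expects):
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#   model_nllb = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
#   tokenizer_nllb = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
#   translate(model_nllb, tokenizer_nllb, "Hello, how are you?", "spa_Latn")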


def just_inference(model, original_path, output_dir, text, lang):
    """Stream XTTS synthesis of `text` in the voice of `original_path` and save
    the concatenated audio to `output_dir` (a file path, despite the name)."""
    print("Inference...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    path_to_save = output_dir
    t0 = time.time()
try:
# Load the audio
print("Loading audio...")
wav, sr = torchaudio.load(original_path)
print(f"Loaded audio with sample rate: {sr}")
        wav = wav.squeeze().numpy()  # drop the channel dimension (assumes mono input)
        print(f"Audio shape after squeezing: {wav.shape}")
# Apply noise reduction
print("Applying noise reduction...")
reduced_noise_audio = nr.reduce_noise(y=wav, sr=sr)
reduced_noise_audio = torch.tensor(reduced_noise_audio).unsqueeze(0)
print(f"Reduced noise audio shape: {reduced_noise_audio.shape}")
# Move the reduced noise audio to the correct device
reduced_noise_audio = reduced_noise_audio.to(device)
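        # NOTE: the denoised signal is computed but never passed to the model;
        # the conditioning below reads `original_path` directly. If denoised
        # conditioning were intended, one option would be to write
        # `reduced_noise_audio` to a temporary file and pass that path instead.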
print("Getting conditioning latents...")
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[original_path])
print("Got conditioning latents.")
print("Starting inference stream...")
        chunks = model.inference_stream(
            text,
            lang,
            gpt_cond_latent,
            speaker_embedding,
            stream_chunk_size=15,  # smaller chunks start audio sooner, with more per-chunk overhead
            speed=0.95  # slightly slower than the default 1.0 speaking rate
        )
print("Inference stream started.")
        full_audio = torch.empty(0, device=device)
        for i, chunk in enumerate(chunks):
            try:
                if i == 0:
                    # The first chunk arrives at index 0, so measure latency here
                    time_to_first_chunk = time.time() - t0
                    print(f"Time to first chunk: {time_to_first_chunk}")
                full_audio = torch.cat((full_audio, chunk.squeeze().to(device)), dim=-1)
                print(f"Processed chunk {i}, chunk shape: {chunk.shape}")
            except Exception as e:
                print(f"Error processing chunk {i}: {e}")
                raise
        # Move full_audio to CPU before saving
        full_audio = full_audio.cpu()
        print(f"Saving full audio to {path_to_save}...")
        torchaudio.save(path_to_save, full_audio.unsqueeze(0), 24000)  # XTTS outputs 24 kHz audio
print("Audio saved.")
print("Inference finished")
return full_audio
except Exception as e:
print(f"Error during processing: {e}")
raise
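

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original demo: the checkpoint
    # directory, reference recording, output path, and language code below are
    # placeholders. Loading follows the standard coqui-ai TTS flow for manual
    # XTTS inference.
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    config = XttsConfig()
    config.load_json("xtts_v2/config.json")
    xtts_model = Xtts.init_from_config(config)
    xtts_model.load_checkpoint(config, checkpoint_dir="xtts_v2", eval=True)

    # Clone the voice in speaker.wav and synthesize Spanish speech to output.wav
    just_inference(xtts_model, "speaker.wav", "output.wav", "Hola, ¿cómo estás?", "es")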