# S3TVR-Demo / app.py — uploaded by yalsaffar (commit aa7cb02, message: "init")
import fastapi
import uvicorn
from fastapi import File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from load_models import get_nllb_model_and_tokenizer, get_xtts_model
from inference_functions import translate, just_inference
import os
import torch
# Limit this process to 75% of GPU 0's memory. The torch call fails on hosts
# without a usable CUDA device, so only apply the cap when one is present.
if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(0.75, 0)
# Load the models once at import time; every request handler below reuses
# these module-level objects.
model_nllb, tokenizer_nllb = get_nllb_model_and_tokenizer()
model_xtts = get_xtts_model()
app = fastapi.FastAPI()
@app.get("/health")
def health_check():
    """Return a constant ``{"status": "ok"}`` payload (usable as a liveness check)."""
    payload = dict(status="ok")
    return payload
@app.post("/translate/")
def translate_text(text: str = Form(...), target_lang: str = Form(...)):
    """Translate the form field ``text`` into ``target_lang``.

    Delegates to ``translate`` with the module-level ``model_nllb`` and
    ``tokenizer_nllb``.

    Returns:
        ``{"translation": <whatever translate returned>}``.
    """
    result = translate(model_nllb, tokenizer_nllb, text, target_lang)
    return dict(translation=result)
@app.post("/inference/")
def inference_audio(
    original_path: UploadFile = File(...),
    text: str = Form(...),
    lang: str = Form(...),
):
    """Run XTTS inference on ``text``/``lang`` with the uploaded audio file.

    The upload is saved under ``/tmp`` and its path is passed to
    ``just_inference`` together with ``text`` and ``lang``. The upload is
    presumably the speaker/reference clip; confirm in ``inference_functions``.

    Returns:
        ``{"path_to_save": <wav path just_inference was asked to write>}``.
    """
    # ``filename`` comes from the client: keep only its last path component so
    # names like "../../etc/cron.d/x" cannot write outside /tmp, and fall back
    # to a fixed name when it is missing or degenerate.
    safe_name = os.path.basename(original_path.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"
    file_location = f"/tmp/{safe_name}"
    with open(file_location, "wb") as file:
        file.write(original_path.file.read())
    output_dir = f"/tmp/generated_audio_{safe_name}.wav"
    # Release cached GPU memory blocks before the inference call.
    torch.cuda.empty_cache()
    just_inference(model_xtts, file_location, output_dir, text, lang)
    return {"path_to_save": output_dir}
@app.post("/process-audio/")
async def process_audio(
    original_path: UploadFile = File(...),
    text: str = Form(...),
    lang: str = Form(...),
    target_lang: str = Form(...),
):
    """Translate ``text`` into ``target_lang``, then synthesize it with XTTS.

    The uploaded audio is saved under ``/tmp`` and passed to ``just_inference``
    along with the translated text.

    Returns:
        JSONResponse with ``{"audio_path": <output wav path>,
        "translation": <translated text>}``.

    Raises:
        HTTPException: 400 if ``target_lang`` is not "es" or "en";
            500 if translation, saving the upload, or inference fails.
    """
    from fastapi.concurrency import run_in_threadpool

    print(f"original_path: {original_path.filename}")
    print(f"text: {text}")
    print(f"lang: {lang}")
    print(f"target_lang: {target_lang}")
    # Only these two target codes are accepted.
    if target_lang not in ("es", "en"):
        print("Unsupported language")
        raise HTTPException(
            status_code=400, detail="Unsupported language. Use 'es' or 'en'."
        )
    try:
        # The model calls are blocking; running them in the threadpool keeps
        # the event loop (and endpoints like /health) responsive meanwhile.
        translated_text = await run_in_threadpool(
            translate, model_nllb, tokenizer_nllb, text, target_lang
        )
        print(f"translated_text: {translated_text}")
        # ``filename`` is client-controlled: strip directories so the write
        # cannot escape /tmp, and fall back to a fixed name if it is unusable.
        safe_name = os.path.basename(original_path.filename or "")
        if safe_name in ("", ".", ".."):
            safe_name = "upload"
        file_location = f"/tmp/{safe_name}"
        contents = await original_path.read()
        with open(file_location, "wb") as file:
            file.write(contents)
        output_dir = f"/tmp/generated_audio_{safe_name}.wav"
        torch.cuda.empty_cache()
        await run_in_threadpool(
            just_inference,
            model_xtts,
            file_location,
            output_dir,
            translated_text,
            target_lang,
        )
        return JSONResponse(
            content={"audio_path": output_dir, "translation": translated_text}
        )
    except Exception as e:
        # Top-level request boundary: log, then surface a generic 500 while
        # keeping the original exception chained for server-side debugging.
        print(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail="Error during processing") from e
@app.get("/download-audio/")
def download_audio(file_path: str):
    """Serve a generated WAV file by path.

    Only regular files directly inside ``/tmp`` whose name starts with
    ``generated_audio_`` are served. Those are the paths returned by
    /inference/ and /process-audio/. The path is resolved first (symlinks and
    ``..`` included). Anything else gets the same 404 as a missing file, so
    the endpoint cannot be used to read arbitrary files on the server.

    Raises:
        HTTPException: 404 if ``file_path`` is not an existing generated file.
    """
    resolved = os.path.realpath(file_path)
    allowed_dir = os.path.realpath("/tmp")
    is_generated_audio = (
        os.path.dirname(resolved) == allowed_dir
        and os.path.basename(resolved).startswith("generated_audio_")
    )
    if not is_generated_audio or not os.path.isfile(resolved):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(
        resolved, media_type='audio/wav', filename=os.path.basename(resolved)
    )
if __name__ == "__main__":
    # Serve on all interfaces, port 8000, when run as a script.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)