# Spaces: Sleeping  (page residue from the hosting page, not part of the program)
import fastapi | |
import uvicorn | |
from fastapi import File, UploadFile, Form, HTTPException | |
from fastapi.responses import JSONResponse, FileResponse | |
from load_models import get_nllb_model_and_tokenizer, get_xtts_model | |
from inference_functions import translate, just_inference | |
import os | |
import torch | |
# Cap this process at a fraction of GPU 0's memory. The call needs a CUDA
# device, so it is skipped on CPU-only hosts instead of crashing at import.
GPU_MEMORY_FRACTION = 0.75
if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION, 0)
# Load models once at import time; the request handlers reuse these globals.
model_nllb, tokenizer_nllb = get_nllb_model_and_tokenizer()
model_xtts = get_xtts_model()
# NOTE(review): no route decorators (e.g. @app.get / @app.post) are visible on
# the handler functions in this source; confirm how they are registered on `app`.
app = fastapi.FastAPI()
def health_check():
    """Report that the service is reachable.

    Returns:
        dict: always ``{"status": "ok"}``.
    """
    payload = dict(status="ok")
    return payload
def translate_text(text: str = Form(...), target_lang: str = Form(...)):
    """Translate the ``text`` form field into ``target_lang`` with the NLLB model.

    Args:
        text: source text to translate.
        target_lang: target language code, forwarded unchanged to ``translate``.

    Returns:
        dict: ``{"translation": <result of translate(...)>}``.
    """
    return dict(translation=translate(model_nllb, tokenizer_nllb, text, target_lang))
def inference_audio(original_path: UploadFile = File(...), text: str = Form(...), lang: str = Form(...)):
    """Synthesize ``text`` using the uploaded audio as the XTTS reference.

    The upload is saved under ``/tmp``, then ``just_inference`` is asked to
    write the generated audio to ``/tmp/generated_audio_<name>.wav``.

    Args:
        original_path: uploaded reference audio file.
        text: text to synthesize.
        lang: language code forwarded to ``just_inference``.

    Returns:
        dict: ``{"path_to_save": <path of the generated wav>}``.
    """
    # The filename is client-controlled: keep only its last path component so
    # names like "../../etc/x" cannot escape /tmp, and fall back when the
    # name is missing or is a bare "." / "..".
    safe_name = os.path.basename(original_path.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"
    file_location = os.path.join("/tmp", safe_name)
    with open(file_location, "wb") as file:
        file.write(original_path.file.read())
    output_dir = f"/tmp/generated_audio_{safe_name}.wav"
    # Release PyTorch's cached-but-unused GPU memory before synthesis
    # (presumably to leave headroom under the per-process memory cap).
    torch.cuda.empty_cache()
    just_inference(model_xtts, file_location, output_dir, text, lang)
    return {"path_to_save": output_dir}
async def process_audio(original_path: UploadFile = File(...), text: str = Form(...), lang: str = Form(...), target_lang: str = Form(...)):
    """Translate ``text`` into ``target_lang``, then synthesize the translation.

    The uploaded audio is saved under ``/tmp`` and used as the XTTS reference.
    The synthesized audio is written to ``/tmp/generated_audio_<name>.wav``.

    Args:
        original_path: uploaded reference audio file.
        text: text to translate.
        lang: source language code; only logged here, never used for processing.
        target_lang: target language code; must be ``"es"`` or ``"en"``.

    Returns:
        JSONResponse: ``{"audio_path": <wav path>, "translation": <text>}``.

    Raises:
        HTTPException: 400 for an unsupported ``target_lang``; 500 if
            translation, saving the upload, or synthesis fails.
    """
    print(f"original_path: {original_path.filename}")
    print(f"text: {text}")
    print(f"lang: {lang}")
    print(f"target_lang: {target_lang}")
    # Validate target language ('es' / 'en' match the example values).
    if target_lang not in ("es", "en"):
        print("Unsupported language")
        raise HTTPException(status_code=400, detail="Unsupported language. Use 'es' or 'en'.")
    # NOTE(review): translate/just_inference are called without `await`, so
    # (assuming they are ordinary functions) they run on the event loop thread
    # and stall other requests while they work. Moving them to a threadpool
    # requires confirming the models are safe to call concurrently.
    try:
        # Translate the text first
        translated_text = translate(model_nllb, tokenizer_nllb, text, target_lang)
        print(f"translated_text: {translated_text}")
        # The filename is client-controlled: keep only its last path component
        # so it cannot escape /tmp, and fall back when it is missing or "."/"..".
        safe_name = os.path.basename(original_path.filename or "")
        if safe_name in ("", ".", ".."):
            safe_name = "upload"
        file_location = os.path.join("/tmp", safe_name)
        # Use the async UploadFile API rather than a blocking file read.
        contents = await original_path.read()
        with open(file_location, "wb") as file:
            file.write(contents)
        output_dir = f"/tmp/generated_audio_{safe_name}.wav"
        torch.cuda.empty_cache()
        just_inference(model_xtts, file_location, output_dir, translated_text, target_lang)
    except Exception as e:
        print(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail="Error during processing") from e
    return JSONResponse(content={"audio_path": output_dir, "translation": translated_text})
def download_audio(file_path: str):
    """Return a generated WAV file to the client.

    ``file_path`` comes from the client, so only files whose names start with
    ``generated_audio_`` and that sit directly in ``/tmp`` are served. These
    are the paths the synthesis handlers return. Any other path, including
    attempts that climb out through ``..`` or symlinks, is reported as
    not found.

    Args:
        file_path: path previously returned by a synthesis handler.

    Returns:
        FileResponse: the file served as ``audio/wav``.

    Raises:
        HTTPException: 404 if the path is outside the allowed location, is
            not a regular file, or does not exist.
    """
    resolved = os.path.realpath(file_path)
    allowed_dir = os.path.realpath("/tmp")
    is_generated_audio = (
        os.path.dirname(resolved) == allowed_dir
        and os.path.basename(resolved).startswith("generated_audio_")
    )
    if not is_generated_audio or not os.path.isfile(resolved):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(resolved, media_type='audio/wav', filename=os.path.basename(resolved))
if __name__ == "__main__":
    # Serve the app directly with uvicorn, listening on every interface.
    listen_host, listen_port = "0.0.0.0", 8000
    uvicorn.run(app, host=listen_host, port=listen_port)