"""FastAPI service exposing text, music, image, speech and NLP endpoints.

Heavy models are loaded on first use and cached in module-level
variables so each one is constructed at most once per process.
"""
import os
from typing import List

import nltk
import spacy
import torch
from aitextgen import aitextgen
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel
from sklearn.datasets import fetch_20newsgroups
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
# NOTE(review): `TTSModel` / `TTSProcessor` do not appear to be exported by
# `transformers`; this import likely raises ImportError -- confirm and replace
# with a real TTS backend.
from transformers import TTSModel, TTSProcessor

# Download the NLTK data and load the spaCy pipeline at import time.
nltk.download('punkt')
nltk.download('stopwords')
spacy_model = spacy.load('en_core_web_sm')

app = FastAPI()

# Module-level caches for the lazily loaded models / datasets.
aitextgen_model = None
hf_model = None
musicgen_model = None
image_generation_model = None
whisper_model = None
whisper_processor = None
tts_model = None
tts_processor = None
newsgroups = None


# Loader functions: each one builds its model only on the first call.
def load_aitextgen_model():
    """Return the cached aitextgen model, creating it on first use."""
    global aitextgen_model
    if aitextgen_model is None:
        aitextgen_model = aitextgen()
    return aitextgen_model


def load_hf_model():
    """Return the cached GPT-2 text-generation pipeline."""
    global hf_model
    if hf_model is None:
        hf_model = pipeline('text-generation', model='gpt2')
    return hf_model


def load_musicgen_model():
    """Return the cached MusicGen model ('small' checkpoint)."""
    global musicgen_model
    if musicgen_model is None:
        musicgen_model = MusicGen.get_pretrained('small')
    return musicgen_model


def load_image_generation_model():
    """Return the cached Stable Diffusion v1-4 pipeline."""
    global image_generation_model
    if image_generation_model is None:
        image_generation_model = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4"
        )
    return image_generation_model


def load_whisper_model():
    """Return the cached ``(model, processor)`` pair for Whisper small."""
    global whisper_model, whisper_processor
    if whisper_model is None:
        whisper_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-small"
        )
        whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    return whisper_model, whisper_processor


def load_tts_model():
    """Return the cached ``(model, processor)`` pair for text-to-speech."""
    global tts_model, tts_processor
    if tts_model is None:
        tts_model = TTSModel.from_pretrained("facebook/tts_transformer-tts")
        tts_processor = TTSProcessor.from_pretrained("facebook/tts_transformer-tts")
    return tts_model, tts_processor


def load_newsgroups():
    """Return the cached list of 20 Newsgroups documents (subset 'all')."""
    global newsgroups
    if newsgroups is None:
        newsgroups = fetch_20newsgroups(subset='all').data
    return newsgroups


class TextRequest(BaseModel):
    """Request body for the text-generation endpoints."""

    prompt: str
    max_length: int = 50


class MusicRequest(BaseModel):
    """Request body for the music-generation endpoint."""

    prompt: str
    duration: float = 10.0


class ImageRequest(BaseModel):
    """Request body for the image-generation endpoint."""

    prompt: str
    height: int = 512
    width: int = 512


class TTSRequest(BaseModel):
    """Request body for the text-to-speech endpoint."""

    text: str


@app.get("/")
def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Text, Music Generation, Image Generation, Whisper, and TTS API!"}


@app.post("/generate/")
def generate_text(request: TextRequest):
    """Generate text from ``request.prompt`` with aitextgen."""
    model = load_aitextgen_model()
    # aitextgen's generate() prints and returns None by default; ask for the
    # results as a list so the response carries the generated text.
    outputs = model.generate(
        n=1,
        prompt=request.prompt,
        max_length=request.max_length,
        return_as_list=True,
    )
    return {"generated_text": outputs[0]}


@app.post("/hf_generate/")
def hf_generate_text(request: TextRequest):
    """Generate text from ``request.prompt`` with the GPT-2 pipeline."""
    generator = load_hf_model()
    outputs = generator(request.prompt, max_length=request.max_length)
    return {"generated_text": outputs[0]['generated_text']}


@app.post("/music/")
def generate_music(request: MusicRequest):
    """Generate music for ``request.prompt`` and write it to a WAV file."""
    model = load_musicgen_model()
    # MusicGen takes the clip length via set_generation_params(); generate()
    # has no per-call duration argument.
    model.set_generation_params(duration=request.duration)
    audio = model.generate([request.prompt])
    # audio_write appends the ".wav" suffix to the stem name itself.
    audio_write('generated_music', audio[0].cpu(), model.sample_rate)
    return {"message": "Music generated successfully", "audio_file": "generated_music.wav"}


@app.post("/generate_image/")
def generate_image(request: ImageRequest):
    """Generate an image for ``request.prompt`` and save it as a PNG."""
    sd_pipeline = load_image_generation_model()
    result = sd_pipeline(request.prompt, height=request.height, width=request.width)
    image_path = "generated_image.png"
    result.images[0].save(image_path)
    return {"message": "Image generated successfully", "image_file": image_path}


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper."""
    model, processor = load_whisper_model()
    audio_input = await file.read()
    # NOTE(review): the processor is handed the raw uploaded bytes; it
    # presumably expects a decoded waveform plus sampling rate -- confirm and
    # decode the upload before this call.
    input_features = processor(audio_input, return_tensors="pt").input_features
    with torch.no_grad():
        predicted_ids = model.generate(input_features)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {"transcription": transcription}


@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and write it to a WAV file."""
    # The processor is loaded alongside the model but is not used here.
    model, _processor = load_tts_model()
    audio = model.generate(request.text)
    audio_path = "generated_speech.wav"
    model.save_wav(audio, audio_path)
    return {"message": "Speech generated successfully", "audio_file": audio_path}


@app.get("/newsgroups/")
def get_newsgroups():
    """Return the first five documents of the 20 Newsgroups dataset."""
    documents = load_newsgroups()
    return {"newsgroups": documents[:5]}


@app.post("/process/")
def process_text(text: str):
    """Tokenize ``text`` with NLTK and extract named entities with spaCy."""
    tokens = nltk.word_tokenize(text)
    doc = spacy_model(text)
    return {
        "tokens": tokens,
        "entities": [(ent.text, ent.label_) for ent in doc.ents],
    }