"""Gradio demo that enhances noisy speech with a pretrained SGMSE score model.

The pretrained checkpoint is downloaded once into the local torch hub cache,
loaded on the CPU, and exposed through a simple upload -> enhanced-audio UI.
"""

import tempfile
import urllib.parse
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from sgmse.model import ScoreModel

# Remote location of the pretrained checkpoint ("%3D" is a URL-encoded "=").
model_path = "https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/pretrained_checkpoints/speech_enhancement/train_vb_29nqe0uh_epoch%3D115.ckpt"


def _fetch_checkpoint(url: str) -> Path:
    """Download *url* into the torch hub cache (if not cached) and return its path.

    ``torch.load`` cannot read from an HTTP URL, so the checkpoint has to
    exist as a local file before it can be deserialised.
    """
    filename = urllib.parse.unquote(url.rsplit("/", 1)[-1])
    target = Path(torch.hub.get_dir()) / "checkpoints" / filename
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        torch.hub.download_url_to_file(url, str(target))
    return target


def _load_model(url: str) -> ScoreModel:
    """Return the pretrained score model, loaded on the CPU in eval mode."""
    checkpoint = _fetch_checkpoint(url)
    # NOTE(review): the ".ckpt" file is presumably a Lightning checkpoint and
    # ScoreModel presumably inherits `load_from_checkpoint` from
    # LightningModule -- confirm against the installed sgmse version.
    loaded = ScoreModel.load_from_checkpoint(
        str(checkpoint), map_location=torch.device("cpu")
    )
    loaded.eval()  # inference only: put the module in evaluation mode
    return loaded


# Loaded once at start-up so every request reuses the same weights.
model = _load_model(model_path)


def enhance_audio(input_audio: str) -> str:
    """Enhance an uploaded audio file and return the path of the result.

    Args:
        input_audio: Filesystem path of the uploaded audio, as supplied by the
            ``gr.Audio(type="filepath")`` input component.

    Returns:
        Path of a newly written WAV file holding the model output, saved at
        the sample rate of the input file.

    Raises:
        gr.Error: If the request was submitted without an audio file.
    """
    if input_audio is None:
        raise gr.Error("Please upload an audio file first.")

    waveform, sample_rate = torchaudio.load(input_audio)

    with torch.no_grad():
        # NOTE(review): the call is kept as originally written; confirm that
        # ScoreModel accepts a bare waveform here -- the installed sgmse
        # version may require a dedicated inference entry point instead.
        enhanced_waveform = model(waveform)

    # A unique file per request, so concurrent users cannot overwrite each
    # other's output (the previous fixed "enhanced_audio.wav" path could).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name
    torchaudio.save(output_path, enhanced_waveform.cpu(), sample_rate)
    return output_path


# Create the Gradio interface.
# NOTE(review): `source="upload"` is the Gradio 3.x spelling; Gradio 4 renamed
# it to `sources=["upload"]` -- adjust if the pinned Gradio version changes.
iface = gr.Interface(
    fn=enhance_audio,
    inputs=gr.Audio(source="upload", type="filepath"),
    # "filepath" matches the path string returned by enhance_audio.
    outputs=gr.Audio(type="filepath"),
    title="Speech Enhancement Model",
    description="Upload a noisy audio file to enhance it using the SGMSE model.",
)

if __name__ == "__main__":
    iface.launch()