Shokoufehhh committed
Commit dd32858 · verified · 1 Parent(s): f61386b

Update app.py

Files changed (1): app.py (+22 -6)
app.py CHANGED
@@ -3,19 +3,33 @@ import torchaudio
 from sgmse.model import ScoreModel
 import gradio as gr
 from sgmse.util.other import pad_spec
+
+# Define the necessary arguments
+class Args:
+    device = 'cpu'  # or 'cuda' if GPU is available and enabled in the environment
+    corrector = 'langevin'  # Define your corrector method
+    N = 50  # Example value for number of steps
+    corrector_steps = 1  # Number of steps for the corrector
+    snr = 0.1  # Signal-to-noise ratio value for the corrector
+    pad_mode = 'reflect'  # Pad mode for spectrogram padding
+
+args = Args()
+
 # Load the pre-trained model
 model = ScoreModel.load_from_checkpoint("https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/train_vb_29nqe0uh_epoch%3D115.ckpt")
+
 def enhance_speech(audio_file):
     # Load and process the audio file
     y, sr = torchaudio.load(audio_file)
-    T_orig = y.size(1)
-    # Normalize
+    T_orig = y.size(1)
+
+    # Normalize
     norm_factor = y.abs().max()
     y = y / norm_factor
 
     # Prepare DNN input
     Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
-    Y = pad_spec(Y, mode=pad_mode)
+    Y = pad_spec(Y, mode=args.pad_mode)
 
     # Reverse sampling
     sampler = model.get_pc_sampler(
@@ -25,7 +39,8 @@ T_orig = y.size(1)
 
     # Backward transform in time domain
     x_hat = model.to_audio(sample.squeeze(), T_orig)
-    # Renormalize
+
+    # Renormalize
     x_hat = x_hat * norm_factor
 
     # Save the enhanced audio
@@ -33,12 +48,13 @@ T_orig = y.size(1)
     torchaudio.save(output_file, x_hat.cpu().numpy(), sr)
 
     return output_file
+
 # Gradio interface setup
 inputs = gr.Audio(label="Input Audio", type="filepath")
 outputs = gr.Audio(label="Output Audio", type="filepath")
 title = "Speech Enhancement using SGMSE"
 description = "This Gradio demo uses the SGMSE model for speech enhancement. Upload your audio file to enhance it."
 article = "<p style='text-align: center'><a href='https://huggingface.co/SP-UHH/speech-enhancement-sgmse' target='_blank'>Model Card</a></p>"
-# Launch without share=True (as it's not supported on Hugging Face Spaces)
-gr.Interface(fn=enhance_speech, inputs=inputs, outputs=outputs, title=title, description=description, article=article).launch(
 
+# Launch without share=True (as it's not supported on Hugging Face Spaces)
+gr.Interface(fn=enhance_speech, inputs=inputs, outputs=outputs, title=title, description=description, article=article).launch()
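
The sampling call that both hunks elide is what actually consumes the new Args fields. As a rough reference only, a sketch of that section is given below. It assumes the get_pc_sampler('reverse_diffusion', corrector, Y, N=..., corrector_steps=..., snr=...) usage from the upstream sp-uhh/sgmse repository; the helper name sample_and_save and the output_file default are made up for illustration and are not the lines hidden by the hunk headers.

import torchaudio

# Rough sketch of the elided sampling + save section, assuming the upstream
# sp-uhh/sgmse API; not the exact lines of this commit.
def sample_and_save(model, Y, T_orig, norm_factor, sr, args, output_file="enhanced.wav"):
    # Predictor-corrector reverse diffusion, driven by the Args fields
    sampler = model.get_pc_sampler(
        'reverse_diffusion', args.corrector, Y.to(args.device),
        N=args.N, corrector_steps=args.corrector_steps, snr=args.snr)
    sample, _ = sampler()

    # Back to the time domain, then undo the input normalization
    x_hat = model.to_audio(sample.squeeze(), T_orig)
    x_hat = x_hat * norm_factor

    # torchaudio.save expects a (channels, time) torch.Tensor
    torchaudio.save(output_file, x_hat.cpu().reshape(1, -1), sr)
    return output_file

Note that torchaudio.save takes a torch.Tensor, so the sketch reshapes x_hat instead of converting it to a NumPy array as the committed torchaudio.save(output_file, x_hat.cpu().numpy(), sr) line does.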