jhj0517 committed on
Commit
c5dc9f3
·
2 Parent(s): 75fca27 8615b10

Merge pull request #181 from jhj0517/feature/integrate-whisperx

app.py CHANGED
@@ -1,15 +1,14 @@
1
- import gradio as gr
2
  import os
3
  import argparse
4
 
5
- from modules.whisper_Inference import WhisperInference
6
- from modules.faster_whisper_inference import FasterWhisperInference
7
- from modules.insanely_fast_whisper_inference import InsanelyFastWhisperInference
8
- from modules.nllb_inference import NLLBInference
9
  from ui.htmls import *
10
- from modules.youtube_manager import get_ytmetas
11
- from modules.deepl_api import DeepLAPI
12
- from modules.whisper_parameter import *
13
 
14
 
15
  class App:
@@ -28,28 +27,35 @@ class App:
28
  )
29
 
30
  def init_whisper(self):
31
  whisper_type = self.args.whisper_type.lower().strip()
32
 
33
  if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
34
  whisper_inf = FasterWhisperInference(
35
  model_dir=self.args.faster_whisper_model_dir,
36
- output_dir=self.args.output_dir
 
37
  )
38
  elif whisper_type in ["whisper"]:
39
  whisper_inf = WhisperInference(
40
  model_dir=self.args.whisper_model_dir,
41
- output_dir=self.args.output_dir
 
42
  )
43
  elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
44
  "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
45
  whisper_inf = InsanelyFastWhisperInference(
46
  model_dir=self.args.insanely_fast_whisper_model_dir,
47
- output_dir=self.args.output_dir
 
48
  )
49
  else:
50
  whisper_inf = FasterWhisperInference(
51
  model_dir=self.args.faster_whisper_model_dir,
52
- output_dir=self.args.output_dir
 
53
  )
54
  return whisper_inf
55
 
@@ -87,7 +93,7 @@ class App:
87
  cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
88
  with gr.Row():
89
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
90
- with gr.Accordion("Advanced_Parameters", open=False):
91
  nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
92
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
93
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -98,14 +104,20 @@ class App:
98
  tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
99
  sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
100
  nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
101
- with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
102
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
103
- sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
104
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
105
  nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
106
  nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
107
  nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
108
  nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
109
  with gr.Accordion("Insanely Fast Whisper Parameters", open=False, visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
110
  nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
111
  nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
@@ -138,10 +150,13 @@ class App:
138
  window_size_sample=nb_window_size_sample,
139
  speech_pad_ms=nb_speech_pad_ms,
140
  chunk_length_s=nb_chunk_length_s,
141
- batch_size=nb_batch_size)
142
 
143
  btn_run.click(fn=self.whisper_inf.transcribe_file,
144
- inputs=params + whisper_params.to_list(),
145
  outputs=[tb_indicator, files_subtitles])
146
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
147
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
@@ -166,7 +181,7 @@ class App:
166
  with gr.Row():
167
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
168
  interactive=True)
169
- with gr.Accordion("Advanced_Parameters", open=False):
170
  nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
171
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
172
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -177,14 +192,20 @@ class App:
177
  tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
178
  sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
179
  nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
180
- with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
181
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
182
- sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
183
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
184
  nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
185
  nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
186
  nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
187
  nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
188
  with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
189
  visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
190
  nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
@@ -218,10 +239,13 @@ class App:
218
  window_size_sample=nb_window_size_sample,
219
  speech_pad_ms=nb_speech_pad_ms,
220
  chunk_length_s=nb_chunk_length_s,
221
- batch_size=nb_batch_size)
222
 
223
  btn_run.click(fn=self.whisper_inf.transcribe_youtube,
224
- inputs=params + whisper_params.to_list(),
225
  outputs=[tb_indicator, files_subtitles])
226
  tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
227
  outputs=[img_thumbnail, tb_title, tb_description])
@@ -239,7 +263,7 @@ class App:
239
  dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
240
  with gr.Row():
241
  cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
242
- with gr.Accordion("Advanced_Parameters", open=False):
243
  nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
244
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
245
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
@@ -249,14 +273,22 @@ class App:
249
  cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
250
  tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
251
  sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
252
- with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
253
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
254
- sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
255
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
256
  nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
257
  nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
258
  nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
259
  nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
260
  with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
261
  visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
262
  nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
@@ -290,10 +322,13 @@ class App:
290
  window_size_sample=nb_window_size_sample,
291
  speech_pad_ms=nb_speech_pad_ms,
292
  chunk_length_s=nb_chunk_length_s,
293
- batch_size=nb_batch_size)
294
 
295
  btn_run.click(fn=self.whisper_inf.transcribe_mic,
296
- inputs=params + whisper_params.to_list(),
297
  outputs=[tb_indicator, files_subtitles])
298
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
299
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
@@ -392,6 +427,7 @@ parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=Tru
392
  parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
393
  parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
394
  parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
 
395
  parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
396
  parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
397
  _args = parser.parse_args()
 
 
1
  import os
2
  import argparse
3
 
4
+ from modules.whisper.whisper_Inference import WhisperInference
5
+ from modules.whisper.faster_whisper_inference import FasterWhisperInference
6
+ from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
7
+ from modules.translation.nllb_inference import NLLBInference
8
  from ui.htmls import *
9
+ from modules.utils.youtube_manager import get_ytmetas
10
+ from modules.translation.deepl_api import DeepLAPI
11
+ from modules.whisper.whisper_parameter import *
12
 
13
 
14
  class App:
 
27
  )
28
 
29
  def init_whisper(self):
30
+ # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
31
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
32
+
33
  whisper_type = self.args.whisper_type.lower().strip()
34
 
35
  if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
36
  whisper_inf = FasterWhisperInference(
37
  model_dir=self.args.faster_whisper_model_dir,
38
+ output_dir=self.args.output_dir,
39
+ args=self.args
40
  )
41
  elif whisper_type in ["whisper"]:
42
  whisper_inf = WhisperInference(
43
  model_dir=self.args.whisper_model_dir,
44
+ output_dir=self.args.output_dir,
45
+ args=self.args
46
  )
47
  elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
48
  "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
49
  whisper_inf = InsanelyFastWhisperInference(
50
  model_dir=self.args.insanely_fast_whisper_model_dir,
51
+ output_dir=self.args.output_dir,
52
+ args=self.args
53
  )
54
  else:
55
  whisper_inf = FasterWhisperInference(
56
  model_dir=self.args.faster_whisper_model_dir,
57
+ output_dir=self.args.output_dir,
58
+ args=self.args
59
  )
60
  return whisper_inf
61
 
 
93
  cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
94
  with gr.Row():
95
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
96
+ with gr.Accordion("Advanced Parameters", open=False):
97
  nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
98
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
99
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
 
104
  tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
105
  sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
106
  nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
107
+ with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
108
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
109
+ sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
110
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
111
  nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
112
  nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
113
  nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
114
  nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
115
+ with gr.Accordion("Diarization", open=False):
116
+ cb_diarize = gr.Checkbox(label="Enable Diarization")
117
+ tb_hf_token = gr.Text(label="HuggingFace Token", value="",
118
+ info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
119
+ "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
120
+ dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
121
  with gr.Accordion("Insanely Fast Whisper Parameters", open=False, visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
122
  nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
123
  nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
 
150
  window_size_sample=nb_window_size_sample,
151
  speech_pad_ms=nb_speech_pad_ms,
152
  chunk_length_s=nb_chunk_length_s,
153
+ batch_size=nb_batch_size,
154
+ is_diarize=cb_diarize,
155
+ hf_token=tb_hf_token,
156
+ diarization_device=dd_diarization_device)
157
 
158
  btn_run.click(fn=self.whisper_inf.transcribe_file,
159
+ inputs=params + whisper_params.as_list(),
160
  outputs=[tb_indicator, files_subtitles])
161
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
162
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
 
181
  with gr.Row():
182
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
183
  interactive=True)
184
+ with gr.Accordion("Advanced Parameters", open=False):
185
  nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
186
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
187
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
 
192
  tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
193
  sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
194
  nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
195
+ with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
196
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
197
+ sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
198
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
199
  nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
200
  nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
201
  nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
202
  nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
203
+ with gr.Accordion("Diarization", open=False):
204
+ cb_diarize = gr.Checkbox(label="Enable Diarization")
205
+ tb_hf_token = gr.Text(label="HuggingFace Token", value="",
206
+ info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
207
+ "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
208
+ dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
209
  with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
210
  visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
211
  nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
 
239
  window_size_sample=nb_window_size_sample,
240
  speech_pad_ms=nb_speech_pad_ms,
241
  chunk_length_s=nb_chunk_length_s,
242
+ batch_size=nb_batch_size,
243
+ is_diarize=cb_diarize,
244
+ hf_token=tb_hf_token,
245
+ diarization_device=dd_diarization_device)
246
 
247
  btn_run.click(fn=self.whisper_inf.transcribe_youtube,
248
+ inputs=params + whisper_params.as_list(),
249
  outputs=[tb_indicator, files_subtitles])
250
  tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
251
  outputs=[img_thumbnail, tb_title, tb_description])
 
263
  dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
264
  with gr.Row():
265
  cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
266
+ with gr.Accordion("Advanced Parameters", open=False):
267
  nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
268
  nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
269
  nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
 
273
  cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
274
  tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
275
  sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
276
+ with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
277
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
278
+ sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
279
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
280
  nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
281
  nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
282
  nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
283
  nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
284
+ with gr.Accordion("Diarization", open=False):
285
+ cb_diarize = gr.Checkbox(label="Enable Diarization")
286
+ tb_hf_token = gr.Text(label="HuggingFace Token", value="",
287
+ info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
288
+ "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
289
+ dd_diarization_device = gr.Dropdown(label="Device",
290
+ choices=self.whisper_inf.diarizer.get_available_device(),
291
+ value=self.whisper_inf.diarizer.get_device())
292
  with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
293
  visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
294
  nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
 
322
  window_size_sample=nb_window_size_sample,
323
  speech_pad_ms=nb_speech_pad_ms,
324
  chunk_length_s=nb_chunk_length_s,
325
+ batch_size=nb_batch_size,
326
+ is_diarize=cb_diarize,
327
+ hf_token=tb_hf_token,
328
+ diarization_device=dd_diarization_device)
329
 
330
  btn_run.click(fn=self.whisper_inf.transcribe_mic,
331
+ inputs=params + whisper_params.as_list(),
332
  outputs=[tb_indicator, files_subtitles])
333
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
334
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
 
427
  parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
428
  parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
429
  parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
430
+ parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"), help='Directory path of the diarization model')
431
  parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
432
  parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
433
  _args = parser.parse_args()
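With init_whisper now passing the parsed arguments through, every inference class receives the Namespace that carries the new --diarization_model_dir default. A minimal construction sketch, assuming defaults equivalent to the parser above (the Namespace below only lists the fields this diff visibly uses; app.py passes the full parse_args() result):

from argparse import Namespace
from modules.whisper.faster_whisper_inference import FasterWhisperInference

# Hypothetical stand-in for parser.parse_args(); the real app passes the full Namespace.
args = Namespace(
    faster_whisper_model_dir="models/Whisper/faster-whisper",
    diarization_model_dir="models/Diarization",
    output_dir="outputs",
)

whisper_inf = FasterWhisperInference(
    model_dir=args.faster_whisper_model_dir,
    output_dir=args.output_dir,
    args=args,
)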
modules/diarize/__init__.py ADDED
File without changes
modules/diarize/audio_loader.py ADDED
@@ -0,0 +1,161 @@
1
+ import os
2
+ import subprocess
3
+ from functools import lru_cache
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ def exact_div(x, y):
11
+ assert x % y == 0
12
+ return x // y
13
+
14
+ # hard-coded audio hyperparameters
15
+ SAMPLE_RATE = 16000
16
+ N_FFT = 400
17
+ HOP_LENGTH = 160
18
+ CHUNK_LENGTH = 30
19
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
20
+ N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
21
+
22
+ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
23
+ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
24
+ TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
25
+
26
+
27
+ def load_audio(file: str, sr: int = SAMPLE_RATE):
28
+ """
29
+ Open an audio file and read as mono waveform, resampling as necessary
30
+
31
+ Parameters
32
+ ----------
33
+ file: str
34
+ The audio file to open
35
+
36
+ sr: int
37
+ The sample rate to resample the audio if necessary
38
+
39
+ Returns
40
+ -------
41
+ A NumPy array containing the audio waveform, in float32 dtype.
42
+ """
43
+ try:
44
+ # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
45
+ # Requires the ffmpeg CLI to be installed.
46
+ cmd = [
47
+ "ffmpeg",
48
+ "-nostdin",
49
+ "-threads",
50
+ "0",
51
+ "-i",
52
+ file,
53
+ "-f",
54
+ "s16le",
55
+ "-ac",
56
+ "1",
57
+ "-acodec",
58
+ "pcm_s16le",
59
+ "-ar",
60
+ str(sr),
61
+ "-",
62
+ ]
63
+ out = subprocess.run(cmd, capture_output=True, check=True).stdout
64
+ except subprocess.CalledProcessError as e:
65
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
66
+
67
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
68
+
69
+
70
+ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
71
+ """
72
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
73
+ """
74
+ if torch.is_tensor(array):
75
+ if array.shape[axis] > length:
76
+ array = array.index_select(
77
+ dim=axis, index=torch.arange(length, device=array.device)
78
+ )
79
+
80
+ if array.shape[axis] < length:
81
+ pad_widths = [(0, 0)] * array.ndim
82
+ pad_widths[axis] = (0, length - array.shape[axis])
83
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
84
+ else:
85
+ if array.shape[axis] > length:
86
+ array = array.take(indices=range(length), axis=axis)
87
+
88
+ if array.shape[axis] < length:
89
+ pad_widths = [(0, 0)] * array.ndim
90
+ pad_widths[axis] = (0, length - array.shape[axis])
91
+ array = np.pad(array, pad_widths)
92
+
93
+ return array
94
+
95
+
96
+ @lru_cache(maxsize=None)
97
+ def mel_filters(device, n_mels: int) -> torch.Tensor:
98
+ """
99
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
100
+ Allows decoupling librosa dependency; saved using:
101
+
102
+ np.savez_compressed(
103
+ "mel_filters.npz",
104
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
105
+ )
106
+ """
107
+ assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
108
+ with np.load(
109
+ os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
110
+ ) as f:
111
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
112
+
113
+
114
+ def log_mel_spectrogram(
115
+ audio: Union[str, np.ndarray, torch.Tensor],
116
+ n_mels: int,
117
+ padding: int = 0,
118
+ device: Optional[Union[str, torch.device]] = None,
119
+ ):
120
+ """
121
+ Compute the log-Mel spectrogram of
122
+
123
+ Parameters
124
+ ----------
125
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
126
+ The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
127
+
128
+ n_mels: int
129
+ The number of Mel-frequency filters, only 80 is supported
130
+
131
+ padding: int
132
+ Number of zero samples to pad to the right
133
+
134
+ device: Optional[Union[str, torch.device]]
135
+ If given, the audio tensor is moved to this device before STFT
136
+
137
+ Returns
138
+ -------
139
+ torch.Tensor, shape = (80, n_frames)
140
+ A Tensor that contains the Mel spectrogram
141
+ """
142
+ if not torch.is_tensor(audio):
143
+ if isinstance(audio, str):
144
+ audio = load_audio(audio)
145
+ audio = torch.from_numpy(audio)
146
+
147
+ if device is not None:
148
+ audio = audio.to(device)
149
+ if padding > 0:
150
+ audio = F.pad(audio, (0, padding))
151
+ window = torch.hann_window(N_FFT).to(audio.device)
152
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
153
+ magnitudes = stft[..., :-1].abs() ** 2
154
+
155
+ filters = mel_filters(audio.device, n_mels)
156
+ mel_spec = filters @ magnitudes
157
+
158
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
159
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
160
+ log_spec = (log_spec + 4.0) / 4.0
161
+ return log_spec
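A short usage sketch of the loader above, assuming the ffmpeg CLI is installed, a local file named sample.wav exists, and the mel_filters.npz asset referenced by mel_filters() is present alongside this module:

from modules.diarize.audio_loader import load_audio, log_mel_spectrogram, SAMPLE_RATE

# Decode to 16 kHz mono float32 by shelling out to ffmpeg.
waveform = load_audio("sample.wav", sr=SAMPLE_RATE)

# Whisper-style log-Mel spectrogram; n_mels must be 80 or 128.
mel = log_mel_spectrogram(waveform, n_mels=80)
print(mel.shape)  # torch.Size([80, n_frames])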
modules/diarize/diarize_pipeline.py ADDED
@@ -0,0 +1,92 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import os
4
+ from pyannote.audio import Pipeline
5
+ from typing import Optional, Union
6
+ import torch
7
+
8
+ from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
9
+
10
+
11
+ class DiarizationPipeline:
12
+ def __init__(
13
+ self,
14
+ model_name="pyannote/speaker-diarization-3.1",
15
+ cache_dir: str = os.path.join("models", "Diarization"),
16
+ use_auth_token=None,
17
+ device: Optional[Union[str, torch.device]] = "cpu",
18
+ ):
19
+ if isinstance(device, str):
20
+ device = torch.device(device)
21
+ self.model = Pipeline.from_pretrained(
22
+ model_name,
23
+ use_auth_token=use_auth_token,
24
+ cache_dir=cache_dir
25
+ ).to(device)
26
+
27
+ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
28
+ if isinstance(audio, str):
29
+ audio = load_audio(audio)
30
+ audio_data = {
31
+ 'waveform': torch.from_numpy(audio[None, :]),
32
+ 'sample_rate': SAMPLE_RATE
33
+ }
34
+ segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
35
+ diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
36
+ diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
37
+ diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
38
+ return diarize_df
39
+
40
+
41
+ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
42
+ transcript_segments = transcript_result["segments"]
43
+ for seg in transcript_segments:
44
+ # assign speaker to segment (if any)
45
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
46
+ seg['start'])
47
+ diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
48
+
49
+ intersected = diarize_df[diarize_df["intersection"] > 0]
50
+
51
+ speaker = None
52
+ if len(intersected) > 0:
53
+ # Choosing most strong intersection
54
+ speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
55
+ elif fill_nearest:
56
+ # Otherwise choosing closest
57
+ speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
58
+
59
+ if speaker is not None:
60
+ seg["speaker"] = speaker
61
+
62
+ # assign speaker to words
63
+ if 'words' in seg:
64
+ for word in seg['words']:
65
+ if 'start' in word:
66
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
67
+ diarize_df['start'], word['start'])
68
+ diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'],
69
+ word['start'])
70
+
71
+ intersected = diarize_df[diarize_df["intersection"] > 0]
72
+
73
+ word_speaker = None
74
+ if len(intersected) > 0:
75
+ # Choosing most strong intersection
76
+ word_speaker = \
77
+ intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
78
+ elif fill_nearest:
79
+ # Otherwise choosing closest
80
+ word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
81
+
82
+ if word_speaker is not None:
83
+ word["speaker"] = word_speaker
84
+
85
+ return transcript_result
86
+
87
+
88
+ class Segment:
89
+ def __init__(self, start, end, speaker=None):
90
+ self.start = start
91
+ self.end = end
92
+ self.speaker = speaker
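A hedged sketch of how DiarizationPipeline and assign_word_speakers compose (the file name, token, and transcript are placeholders; the first run needs a HuggingFace token that has accepted the pyannote/speaker-diarization-3.1 terms):

from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers

pipeline = DiarizationPipeline(use_auth_token="hf_xxx", device="cpu")  # "hf_xxx" is a placeholder
diarize_df = pipeline("sample.wav")  # DataFrame with segment, label, speaker, start, end columns

# Attach the speaker with the largest overlap to each transcribed segment.
transcript = {"segments": [{"start": 0.0, "end": 2.5, "text": " hello there"}]}
result = assign_word_speakers(diarize_df, transcript)
print(result["segments"][0].get("speaker"))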
modules/diarize/diarizer.py ADDED
@@ -0,0 +1,127 @@
1
+ import os
2
+ import torch
3
+ from typing import List
4
+ import time
5
+ import logging
6
+
7
+ from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
8
+ from modules.diarize.audio_loader import load_audio
9
+
10
+ class Diarizer:
11
+ def __init__(self,
12
+ model_dir: str = os.path.join("models", "Diarization")
13
+ ):
14
+ self.device = self.get_device()
15
+ self.available_device = self.get_available_device()
16
+ self.compute_type = "float16"
17
+ self.model_dir = model_dir
18
+ os.makedirs(self.model_dir, exist_ok=True)
19
+ self.pipe = None
20
+
21
+ def run(self,
22
+ audio: str,
23
+ transcribed_result: List[dict],
24
+ use_auth_token: str,
25
+ device: str
26
+ ):
27
+ """
28
+ Diarize transcribed result as a post-processing
29
+
30
+ Parameters
31
+ ----------
32
+ audio: Union[str, BinaryIO, np.ndarray]
33
+ Audio input. This can be file path or binary type.
34
+ transcribed_result: List[dict]
35
+ transcribed result through whisper.
36
+ use_auth_token: str
37
+ Huggingface token with READ permission. This is only needed the first time you download the model.
38
+ You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
39
+ device: str
40
+ Device for diarization.
41
+
42
+ Returns
43
+ ----------
44
+ segments_result: List[dict]
45
+ list of dicts that includes start, end timestamps and transcribed text
46
+ elapsed_time: float
47
+ elapsed time for running
48
+ """
49
+ start_time = time.time()
50
+
51
+ if (device != self.device
52
+ or self.pipe is None):
53
+ self.update_pipe(
54
+ device=device,
55
+ use_auth_token=use_auth_token
56
+ )
57
+
58
+ audio = load_audio(audio)
59
+
60
+ diarization_segments = self.pipe(audio)
61
+ diarized_result = assign_word_speakers(
62
+ diarization_segments,
63
+ {"segments": transcribed_result}
64
+ )
65
+
66
+ for segment in diarized_result["segments"]:
67
+ speaker = "None"
68
+ if "speaker" in segment:
69
+ speaker = segment["speaker"]
70
+ segment["text"] = speaker + "|" + segment["text"][1:]
71
+
72
+ elapsed_time = time.time() - start_time
73
+ return diarized_result["segments"], elapsed_time
74
+
75
+ def update_pipe(self,
76
+ use_auth_token: str,
77
+ device: str
78
+ ):
79
+ """
80
+ Set pipeline for diarization
81
+
82
+ Parameters
83
+ ----------
84
+ use_auth_token: str
85
+ Huggingface token with READ permission. This is only needed the first time you download the model.
86
+ You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
87
+ device: str
88
+ Device for diarization.
89
+ """
90
+
91
+ os.makedirs(self.model_dir, exist_ok=True)
92
+
93
+ if (not os.listdir(self.model_dir) and
94
+ not use_auth_token):
95
+ print(
96
+ "\nFailed to diarize. You need huggingface token and agree to their requirements to download the diarization model.\n"
97
+ "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
98
+ )
99
+ return
100
+
101
+ logger = logging.getLogger("speechbrain.utils.train_logger")
102
+ # Disable redundant torchvision warning message
103
+ logger.disabled = True
104
+ self.pipe = DiarizationPipeline(
105
+ use_auth_token=use_auth_token,
106
+ device=device,
107
+ cache_dir=self.model_dir
108
+ )
109
+ logger.disabled = False
110
+
111
+ @staticmethod
112
+ def get_device():
113
+ if torch.cuda.is_available():
114
+ return "cuda"
115
+ elif torch.backends.mps.is_available():
116
+ return "mps"
117
+ else:
118
+ return "cpu"
119
+
120
+ @staticmethod
121
+ def get_available_device():
122
+ devices = ["cpu"]
123
+ if torch.cuda.is_available():
124
+ devices.append("cuda")
125
+ elif torch.backends.mps.is_available():
126
+ devices.append("mps")
127
+ return devices
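The Diarizer wraps the pipeline for post-processing already-transcribed segments, which is how WhisperBase.run uses it later in this diff. A minimal sketch with placeholder inputs:

from modules.diarize.diarizer import Diarizer

diarizer = Diarizer()  # defaults to models/Diarization
segments = [{"start": 0.0, "end": 2.5, "text": " hello there"}]
diarized_segments, elapsed = diarizer.run(
    audio="sample.wav",
    transcribed_result=segments,
    use_auth_token="hf_xxx",  # placeholder; only needed for the first model download
    device="cpu",
)
# Each segment's text now starts with a "SPEAKER_XX|" prefix.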
modules/translation/__init__.py ADDED
File without changes
modules/{deepl_api.py → translation/deepl_api.py} RENAMED
@@ -4,7 +4,7 @@ import os
4
  from datetime import datetime
5
  import gradio as gr
6
 
7
- from modules.subtitle_manager import *
8
 
9
  """
10
  This is written with reference to the DeepL API documentation.
@@ -144,7 +144,7 @@ class DeepLAPI:
144
  timestamp = datetime.now().strftime("%m%d%H%M%S")
145
 
146
  file_name = file_name[:-9]
147
- output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.srt")
148
  write_file(subtitle, output_path)
149
 
150
  elif file_ext == ".vtt":
@@ -164,7 +164,7 @@ class DeepLAPI:
164
  timestamp = datetime.now().strftime("%m%d%H%M%S")
165
 
166
  file_name = file_name[:-9]
167
- output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.vtt")
168
 
169
  write_file(subtitle, output_path)
170
 
 
4
  from datetime import datetime
5
  import gradio as gr
6
 
7
+ from modules.utils.subtitle_manager import *
8
 
9
  """
10
  This is written with reference to the DeepL API documentation.
 
144
  timestamp = datetime.now().strftime("%m%d%H%M%S")
145
 
146
  file_name = file_name[:-9]
147
+ output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.srt")
148
  write_file(subtitle, output_path)
149
 
150
  elif file_ext == ".vtt":
 
164
  timestamp = datetime.now().strftime("%m%d%H%M%S")
165
 
166
  file_name = file_name[:-9]
167
+ output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.vtt")
168
 
169
  write_file(subtitle, output_path)
170
 
modules/{nllb_inference.py → translation/nllb_inference.py} RENAMED
@@ -2,7 +2,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  import gradio as gr
3
  import os
4
 
5
- from modules.translation_base import TranslationBase
6
 
7
 
8
  class NLLBInference(TranslationBase):
 
2
  import gradio as gr
3
  import os
4
 
5
+ from modules.translation.translation_base import TranslationBase
6
 
7
 
8
  class NLLBInference(TranslationBase):
modules/{translation_base.py → translation/translation_base.py} RENAMED
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
5
  from typing import List
6
  from datetime import datetime
7
 
8
- from modules.whisper_parameter import *
9
- from modules.subtitle_manager import *
10
 
11
 
12
  class TranslationBase(ABC):
@@ -90,9 +90,9 @@ class TranslationBase(ABC):
90
 
91
  timestamp = datetime.now().strftime("%m%d%H%M%S")
92
  if add_timestamp:
93
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
94
  else:
95
- output_path = os.path.join("outputs", "translations", f"{file_name}.srt")
96
 
97
  elif file_ext == ".vtt":
98
  parsed_dicts = parse_vtt(file_path=file_path)
@@ -105,9 +105,9 @@ class TranslationBase(ABC):
105
 
106
  timestamp = datetime.now().strftime("%m%d%H%M%S")
107
  if add_timestamp:
108
- output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.vtt")
109
  else:
110
- output_path = os.path.join(self.output_dir, "translations", f"{file_name}.vtt")
111
 
112
  write_file(subtitle, output_path)
113
  files_info[file_name] = subtitle
 
5
  from typing import List
6
  from datetime import datetime
7
 
8
+ from modules.whisper.whisper_parameter import *
9
+ from modules.utils.subtitle_manager import *
10
 
11
 
12
  class TranslationBase(ABC):
 
90
 
91
  timestamp = datetime.now().strftime("%m%d%H%M%S")
92
  if add_timestamp:
93
+ output_path = os.path.join("outputs", "", f"{file_name}-{timestamp}.srt")
94
  else:
95
+ output_path = os.path.join("outputs", "", f"{file_name}.srt")
96
 
97
  elif file_ext == ".vtt":
98
  parsed_dicts = parse_vtt(file_path=file_path)
 
105
 
106
  timestamp = datetime.now().strftime("%m%d%H%M%S")
107
  if add_timestamp:
108
+ output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.vtt")
109
  else:
110
+ output_path = os.path.join(self.output_dir, "", f"{file_name}.vtt")
111
 
112
  write_file(subtitle, output_path)
113
  files_info[file_name] = subtitle
modules/utils/__init__.py ADDED
File without changes
modules/{subtitle_manager.py → utils/subtitle_manager.py} RENAMED
File without changes
modules/{youtube_manager.py → utils/youtube_manager.py} RENAMED
File without changes
modules/whisper/__init__.py ADDED
File without changes
modules/{faster_whisper_inference.py → whisper/faster_whisper_inference.py} RENAMED
@@ -2,28 +2,27 @@ import os
2
  import time
3
  import numpy as np
4
  from typing import BinaryIO, Union, Tuple, List
5
-
6
  import faster_whisper
7
  from faster_whisper.vad import VadOptions
8
  import ctranslate2
9
  import whisper
10
  import gradio as gr
 
11
 
12
- from modules.whisper_parameter import *
13
- from modules.whisper_base import WhisperBase
14
-
15
- # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
16
- os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
17
 
18
 
19
  class FasterWhisperInference(WhisperBase):
20
  def __init__(self,
21
  model_dir: str,
22
- output_dir: str
 
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
- output_dir=output_dir
 
27
  )
28
  self.model_paths = self.get_model_paths()
29
  self.available_models = self.model_paths.keys()
@@ -45,7 +44,7 @@ class FasterWhisperInference(WhisperBase):
45
  progress: gr.Progress
46
  Indicator to show progress directly in gradio.
47
  *whisper_params: tuple
48
- Gradio components related to Whisper. see whisper_data_class.py for details.
49
 
50
  Returns
51
  ----------
@@ -56,7 +55,7 @@ class FasterWhisperInference(WhisperBase):
56
  """
57
  start_time = time.time()
58
 
59
- params = WhisperParameters.post_process(*whisper_params)
60
 
61
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
62
  self.update_model(params.model_size, params.compute_type, progress)
 
2
  import time
3
  import numpy as np
4
  from typing import BinaryIO, Union, Tuple, List
 
5
  import faster_whisper
6
  from faster_whisper.vad import VadOptions
7
  import ctranslate2
8
  import whisper
9
  import gradio as gr
10
+ from argparse import Namespace
11
 
12
+ from modules.whisper.whisper_parameter import *
13
+ from modules.whisper.whisper_base import WhisperBase
14
 
15
 
16
  class FasterWhisperInference(WhisperBase):
17
  def __init__(self,
18
  model_dir: str,
19
+ output_dir: str,
20
+ args: Namespace
21
  ):
22
  super().__init__(
23
  model_dir=model_dir,
24
+ output_dir=output_dir,
25
+ args=args
26
  )
27
  self.model_paths = self.get_model_paths()
28
  self.available_models = self.model_paths.keys()
 
44
  progress: gr.Progress
45
  Indicator to show progress directly in gradio.
46
  *whisper_params: tuple
47
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
48
 
49
  Returns
50
  ----------
 
55
  """
56
  start_time = time.time()
57
 
58
+ params = WhisperParameters.as_value(*whisper_params)
59
 
60
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
61
  self.update_model(params.model_size, params.compute_type, progress)
modules/{insanely_fast_whisper_inference.py → whisper/insanely_fast_whisper_inference.py} RENAMED
@@ -9,19 +9,22 @@ import gradio as gr
9
  from huggingface_hub import hf_hub_download
10
  import whisper
11
  from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
 
12
 
13
- from modules.whisper_parameter import *
14
- from modules.whisper_base import WhisperBase
15
 
16
 
17
  class InsanelyFastWhisperInference(WhisperBase):
18
  def __init__(self,
19
  model_dir: str,
20
- output_dir: str
 
21
  ):
22
  super().__init__(
23
  model_dir=model_dir,
24
- output_dir=output_dir
 
25
  )
26
  openai_models = whisper.available_models()
27
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
@@ -43,7 +46,7 @@ class InsanelyFastWhisperInference(WhisperBase):
43
  progress: gr.Progress
44
  Indicator to show progress directly in gradio.
45
  *whisper_params: tuple
46
- Gradio components related to Whisper. see whisper_data_class.py for details.
47
 
48
  Returns
49
  ----------
@@ -53,7 +56,7 @@ class InsanelyFastWhisperInference(WhisperBase):
53
  elapsed time for transcription
54
  """
55
  start_time = time.time()
56
- params = WhisperParameters.post_process(*whisper_params)
57
 
58
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
59
  self.update_model(params.model_size, params.compute_type, progress)
 
9
  from huggingface_hub import hf_hub_download
10
  import whisper
11
  from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
12
+ from argparse import Namespace
13
 
14
+ from modules.whisper.whisper_parameter import *
15
+ from modules.whisper.whisper_base import WhisperBase
16
 
17
 
18
  class InsanelyFastWhisperInference(WhisperBase):
19
  def __init__(self,
20
  model_dir: str,
21
+ output_dir: str,
22
+ args: Namespace
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
+ output_dir=output_dir,
27
+ args=args
28
  )
29
  openai_models = whisper.available_models()
30
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
 
46
  progress: gr.Progress
47
  Indicator to show progress directly in gradio.
48
  *whisper_params: tuple
49
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
50
 
51
  Returns
52
  ----------
 
56
  elapsed time for transcription
57
  """
58
  start_time = time.time()
59
+ params = WhisperParameters.as_value(*whisper_params)
60
 
61
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
62
  self.update_model(params.model_size, params.compute_type, progress)
modules/{whisper_Inference.py → whisper/whisper_Inference.py} RENAMED
@@ -1,23 +1,25 @@
1
  import whisper
2
  import gradio as gr
3
  import time
4
- import os
5
  from typing import BinaryIO, Union, Tuple, List
6
  import numpy as np
7
  import torch
 
8
 
9
- from modules.whisper_base import WhisperBase
10
- from modules.whisper_parameter import *
11
 
12
 
13
  class WhisperInference(WhisperBase):
14
  def __init__(self,
15
  model_dir: str,
16
- output_dir: str
 
17
  ):
18
  super().__init__(
19
  model_dir=model_dir,
20
- output_dir=output_dir
 
21
  )
22
 
23
  def transcribe(self,
@@ -35,7 +37,7 @@ class WhisperInference(WhisperBase):
35
  progress: gr.Progress
36
  Indicator to show progress directly in gradio.
37
  *whisper_params: tuple
38
- Gradio components related to Whisper. see whisper_data_class.py for details.
39
 
40
  Returns
41
  ----------
@@ -45,7 +47,7 @@ class WhisperInference(WhisperBase):
45
  elapsed time for transcription
46
  """
47
  start_time = time.time()
48
- params = WhisperParameters.post_process(*whisper_params)
49
 
50
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
51
  self.update_model(params.model_size, params.compute_type, progress)
 
1
  import whisper
2
  import gradio as gr
3
  import time
 
4
  from typing import BinaryIO, Union, Tuple, List
5
  import numpy as np
6
  import torch
7
+ from argparse import Namespace
8
 
9
+ from modules.whisper.whisper_base import WhisperBase
10
+ from modules.whisper.whisper_parameter import *
11
 
12
 
13
  class WhisperInference(WhisperBase):
14
  def __init__(self,
15
  model_dir: str,
16
+ output_dir: str,
17
+ args: Namespace
18
  ):
19
  super().__init__(
20
  model_dir=model_dir,
21
+ output_dir=output_dir,
22
+ args=args
23
  )
24
 
25
  def transcribe(self,
 
37
  progress: gr.Progress
38
  Indicator to show progress directly in gradio.
39
  *whisper_params: tuple
40
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
41
 
42
  Returns
43
  ----------
 
47
  elapsed time for transcription
48
  """
49
  start_time = time.time()
50
+ params = WhisperParameters.as_value(*whisper_params)
51
 
52
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
53
  self.update_model(params.model_size, params.compute_type, progress)
modules/{whisper_base.py → whisper/whisper_base.py} RENAMED
@@ -1,22 +1,24 @@
1
  import os
2
  import torch
3
- from typing import List
4
  import whisper
5
  import gradio as gr
6
  from abc import ABC, abstractmethod
7
  from typing import BinaryIO, Union, Tuple, List
8
  import numpy as np
9
  from datetime import datetime
 
10
 
11
- from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
12
- from modules.youtube_manager import get_ytdata, get_ytaudio
13
- from modules.whisper_parameter import *
 
14
 
15
 
16
  class WhisperBase(ABC):
17
  def __init__(self,
18
  model_dir: str,
19
- output_dir: str
 
20
  ):
21
  self.model = None
22
  self.current_model_size = None
@@ -30,6 +32,9 @@ class WhisperBase(ABC):
30
  self.device = self.get_device()
31
  self.available_compute_types = ["float16", "float32"]
32
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
 
 
 
33
 
34
  @abstractmethod
35
  def transcribe(self,
@@ -47,6 +52,55 @@ class WhisperBase(ABC):
47
  ):
48
  pass
49
50
  def transcribe_file(self,
51
  files: list,
52
  file_format: str,
@@ -68,7 +122,7 @@ class WhisperBase(ABC):
68
  progress: gr.Progress
69
  Indicator to show progress directly in gradio.
70
  *whisper_params: tuple
71
- Gradio components related to Whisper. see whisper_data_class.py for details.
72
 
73
  Returns
74
  ----------
@@ -80,7 +134,7 @@ class WhisperBase(ABC):
80
  try:
81
  files_info = {}
82
  for file in files:
83
- transcribed_segments, time_for_task = self.transcribe(
84
  file.name,
85
  progress,
86
  *whisper_params,
@@ -135,7 +189,7 @@ class WhisperBase(ABC):
135
  progress: gr.Progress
136
  Indicator to show progress directly in gradio.
137
  *whisper_params: tuple
138
- Gradio components related to Whisper. see whisper_data_class.py for details.
139
 
140
  Returns
141
  ----------
@@ -146,7 +200,7 @@ class WhisperBase(ABC):
146
  """
147
  try:
148
  progress(0, desc="Loading Audio..")
149
- transcribed_segments, time_for_task = self.transcribe(
150
  mic_audio,
151
  progress,
152
  *whisper_params,
@@ -190,7 +244,7 @@ class WhisperBase(ABC):
190
  progress: gr.Progress
191
  Indicator to show progress directly in gradio.
192
  *whisper_params: tuple
193
- Gradio components related to Whisper. see whisper_data_class.py for details.
194
 
195
  Returns
196
  ----------
@@ -204,7 +258,7 @@ class WhisperBase(ABC):
204
  yt = get_ytdata(youtube_link)
205
  audio = get_ytaudio(yt)
206
 
207
- transcribed_segments, time_for_task = self.transcribe(
208
  audio,
209
  progress,
210
  *whisper_params,
 
1
  import os
2
  import torch
 
3
  import whisper
4
  import gradio as gr
5
  from abc import ABC, abstractmethod
6
  from typing import BinaryIO, Union, Tuple, List
7
  import numpy as np
8
  from datetime import datetime
9
+ from argparse import Namespace
10
 
11
+ from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
12
+ from modules.utils.youtube_manager import get_ytdata, get_ytaudio
13
+ from modules.whisper.whisper_parameter import *
14
+ from modules.diarize.diarizer import Diarizer
15
 
16
 
17
  class WhisperBase(ABC):
18
  def __init__(self,
19
  model_dir: str,
20
+ output_dir: str,
21
+ args: Namespace
22
  ):
23
  self.model = None
24
  self.current_model_size = None
 
32
  self.device = self.get_device()
33
  self.available_compute_types = ["float16", "float32"]
34
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
35
+ self.diarizer = Diarizer(
36
+ model_dir=args.diarization_model_dir
37
+ )
38
 
39
  @abstractmethod
40
  def transcribe(self,
 
52
  ):
53
  pass
54
 
55
+ def run(self,
56
+ audio: Union[str, BinaryIO, np.ndarray],
57
+ progress: gr.Progress,
58
+ *whisper_params,
59
+ ) -> Tuple[List[dict], float]:
60
+ """
61
+ Run transcription with conditional post-processing.
62
+ The diarization will be performed in post-processing if enabled.
63
+
64
+ Parameters
65
+ ----------
66
+ audio: Union[str, BinaryIO, np.ndarray]
67
+ Audio input. This can be file path or binary type.
68
+ progress: gr.Progress
69
+ Indicator to show progress directly in gradio.
70
+ *whisper_params: tuple
71
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
72
+
73
+ Returns
74
+ ----------
75
+ segments_result: List[dict]
76
+ list of dicts that includes start, end timestamps and transcribed text
77
+ elapsed_time: float
78
+ elapsed time for running
79
+ """
80
+ params = WhisperParameters.as_value(*whisper_params)
81
+
82
+ if params.lang == "Automatic Detection":
83
+ params.lang = None
84
+ else:
85
+ language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
86
+ params.lang = language_code_dict[params.lang]
87
+
88
+ result, elapsed_time = self.transcribe(
89
+ audio,
90
+ progress,
91
+ *whisper_params
92
+ )
93
+
94
+ if params.is_diarize:
95
+ result, elapsed_time_diarization = self.diarizer.run(
96
+ audio=audio,
97
+ use_auth_token=params.hf_token,
98
+ transcribed_result=result,
99
+ device=self.device
100
+ )
101
+ elapsed_time += elapsed_time_diarization
102
+ return result, elapsed_time
103
+
104
  def transcribe_file(self,
105
  files: list,
106
  file_format: str,
 
122
  progress: gr.Progress
123
  Indicator to show progress directly in gradio.
124
  *whisper_params: tuple
125
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
126
 
127
  Returns
128
  ----------
 
134
  try:
135
  files_info = {}
136
  for file in files:
137
+ transcribed_segments, time_for_task = self.run(
138
  file.name,
139
  progress,
140
  *whisper_params,
 
189
  progress: gr.Progress
190
  Indicator to show progress directly in gradio.
191
  *whisper_params: tuple
192
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
193
 
194
  Returns
195
  ----------
 
200
  """
201
  try:
202
  progress(0, desc="Loading Audio..")
203
+ transcribed_segments, time_for_task = self.run(
204
  mic_audio,
205
  progress,
206
  *whisper_params,
 
244
  progress: gr.Progress
245
  Indicator to show progress directly in gradio.
246
  *whisper_params: tuple
247
+ Parameters related with whisper. This will be dealt with "WhisperParameters" data class
248
 
249
  Returns
250
  ----------
 
258
  yt = get_ytdata(youtube_link)
259
  audio = get_ytaudio(yt)
260
 
261
+ transcribed_segments, time_for_task = self.run(
262
  audio,
263
  progress,
264
  *whisper_params,
modules/{whisper_parameter.py → whisper/whisper_parameter.py} RENAMED
@@ -27,6 +27,9 @@ class WhisperParameters:
27
  speech_pad_ms: gr.Number
28
  chunk_length_s: gr.Number
29
  batch_size: gr.Number
30
  """
31
  A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
32
  This data class is used to mitigate the key-value problem between Gradio components and function parameters.
@@ -122,9 +125,19 @@ class WhisperParameters:
122
 
123
  batch_size: gr.Number
124
  This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
125
  """
126
 
127
- def to_list(self) -> list:
128
  """
129
  Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
130
  See more about Gradio pre-processing: : https://www.gradio.app/docs/components
@@ -136,7 +149,7 @@ class WhisperParameters:
136
  return [getattr(self, f.name) for f in fields(self)]
137
 
138
  @staticmethod
139
- def post_process(*args) -> 'WhisperValues':
140
  """
141
  To use Whisper parameters in function after Gradio post-processing.
142
  See more about Gradio post-processing: : https://www.gradio.app/docs/components
@@ -168,7 +181,10 @@ class WhisperParameters:
168
  window_size_samples=args[18],
169
  speech_pad_ms=args[19],
170
  chunk_length_s=args[20],
171
- batch_size=args[21]
172
  )
173
 
174
 
@@ -196,6 +212,9 @@ class WhisperValues:
196
  speech_pad_ms: int
197
  chunk_length_s: int
198
  batch_size: int
199
  """
200
  A data class to use Whisper parameters.
201
  """
 
27
  speech_pad_ms: gr.Number
28
  chunk_length_s: gr.Number
29
  batch_size: gr.Number
30
+ is_diarize: gr.Checkbox
31
+ hf_token: gr.Textbox
32
+ diarization_device: gr.Dropdown
33
  """
34
  A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
35
  This data class is used to mitigate the key-value problem between Gradio components and function parameters.
 
125
 
126
  batch_size: gr.Number
127
  This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
128
+
129
+ is_diarize: gr.Checkbox
130
+ This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
131
+
132
+ hf_token: gr.Textbox
133
+ This parameter is related with whisperx. Huggingface token is needed to download diarization models.
134
+ Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
135
+
136
+ diarization_device: gr.Dropdown
137
+ This parameter is related with whisperx. Device to run diarization model
138
  """
139
 
140
+ def as_list(self) -> list:
141
  """
142
  Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
143
  See more about Gradio pre-processing: : https://www.gradio.app/docs/components
 
149
  return [getattr(self, f.name) for f in fields(self)]
150
 
151
  @staticmethod
152
+ def as_value(*args) -> 'WhisperValues':
153
  """
154
  To use Whisper parameters in function after Gradio post-processing.
155
  See more about Gradio post-processing: : https://www.gradio.app/docs/components
 
181
  window_size_samples=args[18],
182
  speech_pad_ms=args[19],
183
  chunk_length_s=args[20],
184
+ batch_size=args[21],
185
+ is_diarize=args[22],
186
+ hf_token=args[23],
187
+ diarization_device=args[24]
188
  )
189
 
190
 
 
212
  speech_pad_ms: int
213
  chunk_length_s: int
214
  batch_size: int
215
+ is_diarize: bool
216
+ hf_token: str
217
+ diarization_device: str
218
  """
219
  A data class to use Whisper parameters.
220
  """
requirements.txt CHANGED
@@ -4,4 +4,5 @@ git+https://github.com/jhj0517/jhj0517-whisper.git
4
  faster-whisper==1.0.2
5
  transformers
6
  gradio==4.29.0
7
- pytube
 
 
4
  faster-whisper==1.0.2
5
  transformers
6
  gradio==4.29.0
7
+ pytube
8
+ pyannote.audio==3.3.1