jhj0517 committed on
Commit
d4bc29b
unverified
2 Parent(s): ffb268e eec0c16

Merge pull request #363 from jhj0517/feature/refactor-models

app.py CHANGED
@@ -7,17 +7,14 @@ import yaml
7
  from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
8
  INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
9
  UVR_MODELS_DIR, I18N_YAML_PATH)
10
- from modules.utils.constants import AUTOMATIC_DETECTION
11
  from modules.utils.files_manager import load_yaml
12
  from modules.whisper.whisper_factory import WhisperFactory
13
- from modules.whisper.faster_whisper_inference import FasterWhisperInference
14
- from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
15
  from modules.translation.nllb_inference import NLLBInference
16
  from modules.ui.htmls import *
17
  from modules.utils.cli_manager import str2bool
18
  from modules.utils.youtube_manager import get_ytmetas
19
  from modules.translation.deepl_api import DeepLAPI
20
- from modules.whisper.whisper_parameter import *
21
 
22
 
23
  class App:
@@ -44,7 +41,7 @@ class App:
44
  print(f"Use \"{self.args.whisper_type}\" implementation\n"
45
  f"Device \"{self.whisper_inf.device}\" is detected")
46
 
47
- def create_whisper_parameters(self):
48
  whisper_params = self.default_params["whisper"]
49
  vad_params = self.default_params["vad"]
50
  diarization_params = self.default_params["diarization"]
@@ -66,158 +63,31 @@ class App:
66
  interactive=True)
67
 
68
  with gr.Accordion(_("Advanced Parameters"), open=False):
69
- nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0,
70
- interactive=True,
71
- info="Beam size to use for decoding.")
72
- nb_log_prob_threshold = gr.Number(label="Log Probability Threshold",
73
- value=whisper_params["log_prob_threshold"], interactive=True,
74
- info="If the average log probability over sampled tokens is below this value, treat as failed.")
75
- nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"],
76
- interactive=True,
77
- info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
78
- dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
79
- value=self.whisper_inf.current_compute_type, interactive=True,
80
- allow_custom_value=True,
81
- info="Select the type of computation to perform.")
82
- nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
83
- info="Number of candidates when sampling with non-zero temperature.")
84
- nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
85
- info="Beam search patience factor.")
86
- cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text",
87
- value=whisper_params["condition_on_previous_text"],
88
- interactive=True,
89
- info="Condition on previous text during decoding.")
90
- sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature",
91
- value=whisper_params["prompt_reset_on_temperature"],
92
- minimum=0, maximum=1, step=0.01, interactive=True,
93
- info="Resets prompt if temperature is above this value."
94
- " Arg has effect only if 'Condition On Previous Text' is True.")
95
- tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
96
- info="Initial prompt to use for decoding.")
97
- sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
98
- step=0.01, maximum=1.0, interactive=True,
99
- info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
100
- nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold",
101
- value=whisper_params["compression_ratio_threshold"],
102
- interactive=True,
103
- info="If the gzip compression ratio is above this value, treat as failed.")
104
- nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"],
105
- precision=0,
106
- info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
107
- with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
108
- nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
109
- info="Exponential length penalty constant.")
110
- nb_repetition_penalty = gr.Number(label="Repetition Penalty",
111
- value=whisper_params["repetition_penalty"],
112
- info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
113
- nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size",
114
- value=whisper_params["no_repeat_ngram_size"],
115
- precision=0,
116
- info="Prevent repetitions of n-grams with this size (set 0 to disable).")
117
- tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
118
- info="Optional text to provide as a prefix for the first window.")
119
- cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
120
- info="Suppress blank outputs at the beginning of the sampling.")
121
- tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
122
- info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
123
- nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp",
124
- value=whisper_params["max_initial_timestamp"],
125
- info="The initial timestamp cannot be later than this.")
126
- cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
127
- info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
128
- tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations",
129
- value=whisper_params["prepend_punctuations"],
130
- info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
131
- tb_append_punctuations = gr.Textbox(label="Append Punctuations",
132
- value=whisper_params["append_punctuations"],
133
- info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
134
- nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
135
- precision=0,
136
- info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
137
- nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
138
- value=lambda: whisper_params[
139
- "hallucination_silence_threshold"],
140
- info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
141
- tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
142
- info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
143
- nb_language_detection_threshold = gr.Number(label="Language Detection Threshold",
144
- value=lambda: whisper_params[
145
- "language_detection_threshold"],
146
- info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
147
- nb_language_detection_segments = gr.Number(label="Language Detection Segments",
148
- value=lambda: whisper_params["language_detection_segments"],
149
- precision=0,
150
- info="Number of segments to consider for the language detection.")
151
- with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
152
- nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
153
 
154
  with gr.Accordion(_("Background Music Remover Filter"), open=False):
155
- cb_bgm_separation = gr.Checkbox(label=_("Enable Background Music Remover Filter"),
156
- value=uvr_params["is_separate_bgm"],
157
- interactive=True,
158
- info=_("Enabling this will remove background music"))
159
- dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device,
160
- choices=self.whisper_inf.music_separator.available_devices)
161
- dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"],
162
- choices=self.whisper_inf.music_separator.available_models)
163
- nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0)
164
- cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"])
165
- cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"),
166
- value=uvr_params["enable_offload"])
167
 
168
  with gr.Accordion(_("Voice Detection Filter"), open=False):
169
- cb_vad_filter = gr.Checkbox(label=_("Enable Silero VAD Filter"), value=vad_params["vad_filter"],
170
- interactive=True,
171
- info=_("Enable this to transcribe only detected voice"))
172
- sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
173
- value=vad_params["threshold"],
174
- info="Lower it to be more sensitive to small sounds.")
175
- nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
176
- value=vad_params["min_speech_duration_ms"],
177
- info="Final speech chunks shorter than this time are thrown out")
178
- nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
179
- value=vad_params["max_speech_duration_s"],
180
- info="Maximum duration of speech chunks in \"seconds\".")
181
- nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0,
182
- value=vad_params["min_silence_duration_ms"],
183
- info="In the end of each speech chunk wait for this time"
184
- " before separating it")
185
- nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
186
- info="Final speech chunks are padded by this time each side")
187
 
188
  with gr.Accordion(_("Diarization"), open=False):
189
- cb_diarize = gr.Checkbox(label=_("Enable Diarization"), value=diarization_params["is_diarize"])
190
- tb_hf_token = gr.Text(label=_("HuggingFace Token"), value=diarization_params["hf_token"],
191
- info=_("This is only needed the first time you download the model"))
192
- dd_diarization_device = gr.Dropdown(label=_("Device"),
193
- choices=self.whisper_inf.diarizer.get_available_device(),
194
- value=self.whisper_inf.diarizer.get_device())
195
 
196
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
197
 
 
 
198
  return (
199
- WhisperParameters(
200
- model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size,
201
- log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold,
202
- compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience,
203
- condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt,
204
- temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
205
- vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
206
- max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
207
- speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size,
208
- is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
209
- length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
210
- no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
211
- suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
212
- word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
213
- append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens,
214
- hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
215
- language_detection_threshold=nb_language_detection_threshold,
216
- language_detection_segments=nb_language_detection_segments,
217
- prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation,
218
- uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size,
219
- uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload
220
- ),
221
  dd_file_format,
222
  cb_timestamp
223
  )
@@ -243,7 +113,7 @@ class App:
243
  visible=self.args.colab,
244
  value="")
245
 
246
- whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
247
 
248
  with gr.Row():
249
  btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -254,7 +124,7 @@ class App:
254
 
255
  params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
256
  btn_run.click(fn=self.whisper_inf.transcribe_file,
257
- inputs=params + whisper_params.as_list(),
258
  outputs=[tb_indicator, files_subtitles])
259
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
260
 
@@ -268,7 +138,7 @@ class App:
268
  tb_title = gr.Label(label=_("Youtube Title"))
269
  tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15)
270
 
271
- whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
272
 
273
  with gr.Row():
274
  btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -280,7 +150,7 @@ class App:
280
  params = [tb_youtubelink, dd_file_format, cb_timestamp]
281
 
282
  btn_run.click(fn=self.whisper_inf.transcribe_youtube,
283
- inputs=params + whisper_params.as_list(),
284
  outputs=[tb_indicator, files_subtitles])
285
  tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
286
  outputs=[img_thumbnail, tb_title, tb_description])
@@ -290,7 +160,7 @@ class App:
290
  with gr.Row():
291
  mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True)
292
 
293
- whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
294
 
295
  with gr.Row():
296
  btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -302,7 +172,7 @@ class App:
302
  params = [mic_input, dd_file_format, cb_timestamp]
303
 
304
  btn_run.click(fn=self.whisper_inf.transcribe_mic,
305
- inputs=params + whisper_params.as_list(),
306
  outputs=[tb_indicator, files_subtitles])
307
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
308
 
@@ -417,7 +287,6 @@ class App:
417
 
418
  # Launch the app with optional gradio settings
419
  args = self.args
420
-
421
  self.app.queue(
422
  api_open=args.api_open
423
  ).launch(
@@ -447,8 +316,8 @@ class App:
447
 
448
 
449
  parser = argparse.ArgumentParser()
450
- parser.add_argument('--whisper_type', type=str, default="faster-whisper",
451
- choices=["whisper", "faster-whisper", "insanely-fast-whisper"],
452
  help='A type of the whisper implementation (Github repo name)')
453
  parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value')
454
  parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
 
7
  from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
8
  INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
9
  UVR_MODELS_DIR, I18N_YAML_PATH)
 
10
  from modules.utils.files_manager import load_yaml
11
  from modules.whisper.whisper_factory import WhisperFactory
 
 
12
  from modules.translation.nllb_inference import NLLBInference
13
  from modules.ui.htmls import *
14
  from modules.utils.cli_manager import str2bool
15
  from modules.utils.youtube_manager import get_ytmetas
16
  from modules.translation.deepl_api import DeepLAPI
17
+ from modules.whisper.data_classes import *
18
 
19
 
20
  class App:
 
41
  print(f"Use \"{self.args.whisper_type}\" implementation\n"
42
  f"Device \"{self.whisper_inf.device}\" is detected")
43
 
44
+ def create_pipeline_inputs(self):
45
  whisper_params = self.default_params["whisper"]
46
  vad_params = self.default_params["vad"]
47
  diarization_params = self.default_params["diarization"]
 
63
  interactive=True)
64
 
65
  with gr.Accordion(_("Advanced Parameters"), open=False):
66
+ whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True,
67
+ whisper_type=self.args.whisper_type,
68
+ available_compute_types=self.whisper_inf.available_compute_types,
69
+ compute_type=self.whisper_inf.current_compute_type)
 
70
 
71
  with gr.Accordion(_("Background Music Remover Filter"), open=False):
72
+ uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params,
73
+ available_models=self.whisper_inf.music_separator.available_models,
74
+ available_devices=self.whisper_inf.music_separator.available_devices,
75
+ device=self.whisper_inf.music_separator.device)
76
 
77
  with gr.Accordion(_("Voice Detection Filter"), open=False):
78
+ vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params)
79
 
80
  with gr.Accordion(_("Diarization"), open=False):
81
+ diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params,
82
+ available_devices=self.whisper_inf.diarizer.available_device,
83
+ device=self.whisper_inf.diarizer.device)
84
 
85
  dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
86
 
87
+ pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs
88
+
89
  return (
90
+ pipeline_inputs,
91
  dd_file_format,
92
  cb_timestamp
93
  )
 
113
  visible=self.args.colab,
114
  value="")
115
 
116
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
117
 
118
  with gr.Row():
119
  btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
 
124
 
125
  params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
126
  btn_run.click(fn=self.whisper_inf.transcribe_file,
127
+ inputs=params + pipeline_params,
128
  outputs=[tb_indicator, files_subtitles])
129
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
130
 
 
138
  tb_title = gr.Label(label=_("Youtube Title"))
139
  tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15)
140
 
141
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
142
 
143
  with gr.Row():
144
  btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
 
150
  params = [tb_youtubelink, dd_file_format, cb_timestamp]
151
 
152
  btn_run.click(fn=self.whisper_inf.transcribe_youtube,
153
+ inputs=params + pipeline_params,
154
  outputs=[tb_indicator, files_subtitles])
155
  tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
156
  outputs=[img_thumbnail, tb_title, tb_description])
 
160
  with gr.Row():
161
  mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True)
162
 
163
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
164
 
165
  with gr.Row():
166
  btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
 
172
  params = [mic_input, dd_file_format, cb_timestamp]
173
 
174
  btn_run.click(fn=self.whisper_inf.transcribe_mic,
175
+ inputs=params + pipeline_params,
176
  outputs=[tb_indicator, files_subtitles])
177
  btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
178
 
 
287
 
288
  # Launch the app with optional gradio settings
289
  args = self.args
 
290
  self.app.queue(
291
  api_open=args.api_open
292
  ).launch(
 
316
 
317
 
318
  parser = argparse.ArgumentParser()
319
+ parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value,
320
+ choices=[item.value for item in WhisperImpl],
321
  help='A type of the whisper implementation (Github repo name)')
322
  parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value')
323
  parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
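
For orientation (illustrative sketch, not part of this diff): the core of the app.py refactor is that each parameter group is now a pydantic model that both builds its Gradio components (`to_gradio_inputs`) and rebuilds itself from the flat value list Gradio hands to the click handler (`from_list`, defined on `BaseParams` in the new data_classes.py). The stand-in models and the `rebuild` helper below are hypothetical; the sketch only shows the flatten/rebuild pattern and assumes pydantic v2 is installed.

    # Hypothetical stand-ins; the real classes are WhisperParams, VadParams, etc.
    from typing import List
    from pydantic import BaseModel

    class WhisperDemoParams(BaseModel):
        beam_size: int = 5
        temperature: float = 0.0

    class VadDemoParams(BaseModel):
        vad_filter: bool = False
        threshold: float = 0.5

    def rebuild(flat_values: List):
        # Gradio passes every component value positionally, so the flat list is
        # split by each model's field count and zipped back into keyword arguments.
        n = len(WhisperDemoParams.model_fields)
        whisper = WhisperDemoParams(**dict(zip(WhisperDemoParams.model_fields, flat_values[:n])))
        vad = VadDemoParams(**dict(zip(VadDemoParams.model_fields, flat_values[n:])))
        return whisper, vad

    print(rebuild([5, 0.2, True, 0.6]))  # (WhisperDemoParams(...), VadDemoParams(...))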
modules/diarize/diarize_pipeline.py CHANGED
@@ -44,6 +44,7 @@ class DiarizationPipeline:
44
  def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
45
  transcript_segments = transcript_result["segments"]
46
  for seg in transcript_segments:
 
47
  # assign speaker to segment (if any)
48
  diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
49
  seg['start'])
 
44
  def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
45
  transcript_segments = transcript_result["segments"]
46
  for seg in transcript_segments:
47
+ seg = seg.dict()
48
  # assign speaker to segment (if any)
49
  diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
50
  seg['start'])
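
For orientation (illustrative sketch, not part of this diff): with this change each transcript segment arrives as a pydantic Segment and is converted to a dict before the overlap math runs. The snippet below sketches the speaker-by-overlap idea behind the intersection line above; the final selection step is simplified here (the real pipeline aggregates per speaker) and it assumes numpy and pandas are available.

    import numpy as np
    import pandas as pd

    diarize_df = pd.DataFrame({"start": [0.0, 2.0], "end": [2.0, 4.0],
                               "speaker": ["SPEAKER_00", "SPEAKER_01"]})
    seg = {"start": 2.2, "end": 3.8, "text": "hello"}  # shape of Segment.dict() after this PR

    # Overlap between the transcript segment and every diarization turn
    diarize_df["intersection"] = np.minimum(diarize_df["end"], seg["end"]) - \
                                 np.maximum(diarize_df["start"], seg["start"])
    # Pick the turn with the largest overlap (simplified selection step)
    seg["speaker"] = diarize_df.loc[diarize_df["intersection"].idxmax(), "speaker"]
    print(seg["speaker"])  # SPEAKER_01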
modules/diarize/diarizer.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import torch
3
- from typing import List, Union, BinaryIO, Optional
4
  import numpy as np
5
  import time
6
  import logging
@@ -8,6 +8,7 @@ import logging
8
  from modules.utils.paths import DIARIZATION_MODELS_DIR
9
  from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
10
  from modules.diarize.audio_loader import load_audio
 
11
 
12
 
13
  class Diarizer:
@@ -23,10 +24,10 @@ class Diarizer:
23
 
24
  def run(self,
25
  audio: Union[str, BinaryIO, np.ndarray],
26
- transcribed_result: List[dict],
27
  use_auth_token: str,
28
  device: Optional[str] = None
29
- ):
30
  """
31
  Diarize transcribed result as a post-processing
32
 
@@ -34,7 +35,7 @@ class Diarizer:
34
  ----------
35
  audio: Union[str, BinaryIO, np.ndarray]
36
  Audio input. This can be file path or binary type.
37
- transcribed_result: List[dict]
38
  transcribed result through whisper.
39
  use_auth_token: str
40
  Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ class Diarizer:
44
 
45
  Returns
46
  ----------
47
- segments_result: List[dict]
48
- list of dicts that includes start, end timestamps and transcribed text
49
  elapsed_time: float
50
  elapsed time for running
51
  """
@@ -68,14 +69,21 @@ class Diarizer:
68
  {"segments": transcribed_result}
69
  )
70
 
 
71
  for segment in diarized_result["segments"]:
 
72
  speaker = "None"
73
  if "speaker" in segment:
74
  speaker = segment["speaker"]
75
- segment["text"] = speaker + "|" + segment["text"].strip()
76
 
77
  elapsed_time = time.time() - start_time
78
- return diarized_result["segments"], elapsed_time
79
 
80
  def update_pipe(self,
81
  use_auth_token: str,
 
1
  import os
2
  import torch
3
+ from typing import List, Union, BinaryIO, Optional, Tuple
4
  import numpy as np
5
  import time
6
  import logging
 
8
  from modules.utils.paths import DIARIZATION_MODELS_DIR
9
  from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
10
  from modules.diarize.audio_loader import load_audio
11
+ from modules.whisper.data_classes import *
12
 
13
 
14
  class Diarizer:
 
24
 
25
  def run(self,
26
  audio: Union[str, BinaryIO, np.ndarray],
27
+ transcribed_result: List[Segment],
28
  use_auth_token: str,
29
  device: Optional[str] = None
30
+ ) -> Tuple[List[Segment], float]:
31
  """
32
  Diarize transcribed result as a post-processing
33
 
 
35
  ----------
36
  audio: Union[str, BinaryIO, np.ndarray]
37
  Audio input. This can be file path or binary type.
38
+ transcribed_result: List[Segment]
39
  transcribed result through whisper.
40
  use_auth_token: str
41
  Huggingface token with READ permission. This is only needed the first time you download the model.
 
45
 
46
  Returns
47
  ----------
48
+ segments_result: List[Segment]
49
+ list of Segment objects that include start/end timestamps and transcribed text
50
  elapsed_time: float
51
  elapsed time for running
52
  """
 
69
  {"segments": transcribed_result}
70
  )
71
 
72
+ segments_result = []
73
  for segment in diarized_result["segments"]:
74
+ segment = segment.dict()
75
  speaker = "None"
76
  if "speaker" in segment:
77
  speaker = segment["speaker"]
78
+ diarized_text = speaker + "|" + segment["text"].strip()
79
+ segments_result.append(Segment(
80
+ start=segment["start"],
81
+ end=segment["end"],
82
+ text=diarized_text
83
+ ))
84
 
85
  elapsed_time = time.time() - start_time
86
+ return segments_result, elapsed_time
87
 
88
  def update_pipe(self,
89
  use_auth_token: str,
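
For orientation (illustrative sketch, not part of this diff): Diarizer.run() now returns a list of Segment objects whose text is prefixed with the detected speaker as "speaker|text". The snippet below re-declares a Segment with the same fields so it runs standalone with pydantic v2; the sample data is hypothetical.

    from typing import Optional
    from pydantic import BaseModel

    class Segment(BaseModel):  # same fields as modules.whisper.data_classes.Segment
        text: Optional[str] = None
        start: Optional[float] = None
        end: Optional[float] = None

    diarized = [{"start": 0.0, "end": 2.1, "text": " hello there", "speaker": "SPEAKER_00"},
                {"start": 2.1, "end": 3.0, "text": " hi"}]  # second entry: no speaker detected
    segments_result = [
        Segment(start=seg["start"], end=seg["end"],
                text=seg.get("speaker", "None") + "|" + seg["text"].strip())
        for seg in diarized
    ]
    print(segments_result[0].text)  # SPEAKER_00|hello there
    print(segments_result[1].text)  # None|hi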
modules/translation/translation_base.py CHANGED
@@ -6,7 +6,7 @@ from typing import List
6
  from datetime import datetime
7
 
8
  import modules.translation.nllb_inference as nllb
9
- from modules.whisper.whisper_parameter import *
10
  from modules.utils.subtitle_manager import *
11
  from modules.utils.files_manager import load_yaml, save_yaml
12
  from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
 
6
  from datetime import datetime
7
 
8
  import modules.translation.nllb_inference as nllb
9
+ from modules.whisper.data_classes import *
10
  from modules.utils.subtitle_manager import *
11
  from modules.utils.files_manager import load_yaml, save_yaml
12
  from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
modules/utils/constants.py CHANGED
@@ -1,3 +1,6 @@
1
  from gradio_i18n import Translate, gettext as _
2
 
3
  AUTOMATIC_DETECTION = _("Automatic Detection")
1
  from gradio_i18n import Translate, gettext as _
2
 
3
  AUTOMATIC_DETECTION = _("Automatic Detection")
4
+ GRADIO_NONE_STR = ""
5
+ GRADIO_NONE_NUMBER_MAX = 9999
6
+ GRADIO_NONE_NUMBER_MIN = 0
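
For orientation (illustrative sketch, not part of this diff): these sentinels exist because Gradio components cannot display None (or float('inf')) directly, so the UI shows "", 0, or 9999 and the pipeline maps them back before transcription. A tiny standalone sketch of that mapping, mirroring validate_gradio_values() added to the base pipeline below; the `unsentinel` helper is hypothetical.

    GRADIO_NONE_STR = ""
    GRADIO_NONE_NUMBER_MAX = 9999

    def unsentinel(initial_prompt, max_speech_duration_s):
        if initial_prompt == GRADIO_NONE_STR:
            initial_prompt = None
        if max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
            max_speech_duration_s = float("inf")
        return initial_prompt, max_speech_duration_s

    print(unsentinel("", 9999))  # (None, inf)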
modules/utils/subtitle_manager.py CHANGED
@@ -1,5 +1,7 @@
1
  import re
2
 
 
 
3
 
4
  def timeformat_srt(time):
5
  hours = time // 3600
@@ -23,6 +25,9 @@ def write_file(subtitle, output_file):
23
 
24
 
25
  def get_srt(segments):
26
  output = ""
27
  for i, segment in enumerate(segments):
28
  output += f"{i + 1}\n"
@@ -34,6 +39,9 @@ def get_srt(segments):
34
 
35
 
36
  def get_vtt(segments):
37
  output = "WEBVTT\n\n"
38
  for i, segment in enumerate(segments):
39
  output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
@@ -44,6 +52,9 @@ def get_vtt(segments):
44
 
45
 
46
  def get_txt(segments):
47
  output = ""
48
  for i, segment in enumerate(segments):
49
  if segment['text'].startswith(' '):
 
1
  import re
2
 
3
+ from modules.whisper.data_classes import Segment
4
+
5
 
6
  def timeformat_srt(time):
7
  hours = time // 3600
 
25
 
26
 
27
  def get_srt(segments):
28
+ if segments and isinstance(segments[0], Segment):
29
+ segments = [seg.dict() for seg in segments]
30
+
31
  output = ""
32
  for i, segment in enumerate(segments):
33
  output += f"{i + 1}\n"
 
39
 
40
 
41
  def get_vtt(segments):
42
+ if segments and isinstance(segments[0], Segment):
43
+ segments = [seg.dict() for seg in segments]
44
+
45
  output = "WEBVTT\n\n"
46
  for i, segment in enumerate(segments):
47
  output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
 
52
 
53
 
54
  def get_txt(segments):
55
+ if segments and isinstance(segments[0], Segment):
56
+ segments = [seg.dict() for seg in segments]
57
+
58
  output = ""
59
  for i, segment in enumerate(segments):
60
  if segment['text'].startswith(' '):
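
Usage sketch (not part of this diff): with the guards added above, the subtitle writers accept either plain dicts or the new pydantic Segment objects. This assumes it is run from the repository root with this PR applied.

    from modules.whisper.data_classes import Segment      # added in this PR
    from modules.utils.subtitle_manager import get_srt

    segments = [Segment(start=0.0, end=1.2, text=" Hello"),
                Segment(start=1.2, end=2.5, text=" world")]
    print(get_srt(segments))  # numbered SRT blocks rendered from Segment objects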
modules/vad/silero_vad.py CHANGED
@@ -5,7 +5,8 @@ import numpy as np
5
  from typing import BinaryIO, Union, List, Optional, Tuple
6
  import warnings
7
  import faster_whisper
8
- from faster_whisper.transcribe import SpeechTimestampsMap, Segment
 
9
  import gradio as gr
10
 
11
 
@@ -247,18 +248,18 @@ class SileroVAD:
247
 
248
  def restore_speech_timestamps(
249
  self,
250
- segments: List[dict],
251
  speech_chunks: List[dict],
252
  sampling_rate: Optional[int] = None,
253
- ) -> List[dict]:
254
  if sampling_rate is None:
255
  sampling_rate = self.sampling_rate
256
 
257
  ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
258
 
259
  for segment in segments:
260
- segment["start"] = ts_map.get_original_time(segment["start"])
261
- segment["end"] = ts_map.get_original_time(segment["end"])
262
 
263
  return segments
264
 
 
5
  from typing import BinaryIO, Union, List, Optional, Tuple
6
  import warnings
7
  import faster_whisper
8
+ from modules.whisper.data_classes import *
9
+ from faster_whisper.transcribe import SpeechTimestampsMap
10
  import gradio as gr
11
 
12
 
 
248
 
249
  def restore_speech_timestamps(
250
  self,
251
+ segments: List[Segment],
252
  speech_chunks: List[dict],
253
  sampling_rate: Optional[int] = None,
254
+ ) -> List[Segment]:
255
  if sampling_rate is None:
256
  sampling_rate = self.sampling_rate
257
 
258
  ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
259
 
260
  for segment in segments:
261
+ segment.start = ts_map.get_original_time(segment.start)
262
+ segment.end = ts_map.get_original_time(segment.end)
263
 
264
  return segments
265
 
modules/whisper/{whisper_base.py → base_transcription_pipeline.py} RENAMED
@@ -1,5 +1,6 @@
1
  import os
2
  import torch
 
3
  import whisper
4
  import ctranslate2
5
  import gradio as gr
@@ -14,16 +15,16 @@ from dataclasses import astuple
14
  from modules.uvr.music_separator import MusicSeparator
15
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
16
  UVR_MODELS_DIR)
17
- from modules.utils.constants import AUTOMATIC_DETECTION
18
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
19
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
20
  from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
21
- from modules.whisper.whisper_parameter import *
22
  from modules.diarize.diarizer import Diarizer
23
  from modules.vad.silero_vad import SileroVAD
24
 
25
 
26
- class WhisperBase(ABC):
27
  def __init__(self,
28
  model_dir: str = WHISPER_MODELS_DIR,
29
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -74,12 +75,13 @@ class WhisperBase(ABC):
74
  audio: Union[str, BinaryIO, np.ndarray],
75
  progress: gr.Progress = gr.Progress(),
76
  add_timestamp: bool = True,
77
- *whisper_params,
78
  ) -> Tuple[List[dict], float]:
79
  """
80
  Run transcription with conditional pre-processing and post-processing.
81
  The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
82
  The diarization will be performed in post-processing, if enabled.
 
83
 
84
  Parameters
85
  ----------
@@ -89,8 +91,8 @@ class WhisperBase(ABC):
89
  Indicator to show progress directly in gradio.
90
  add_timestamp: bool
91
  Whether to add a timestamp at the end of the filename.
92
- *whisper_params: tuple
93
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
94
 
95
  Returns
96
  ----------
@@ -99,28 +101,17 @@ class WhisperBase(ABC):
99
  elapsed_time: float
100
  elapsed time for running
101
  """
102
- params = WhisperParameters.as_value(*whisper_params)
103
-
104
- self.cache_parameters(
105
- whisper_params=params,
106
- add_timestamp=add_timestamp
107
- )
108
 
109
- if params.lang is None:
110
- pass
111
- elif params.lang == AUTOMATIC_DETECTION:
112
- params.lang = None
113
- else:
114
- language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
115
- params.lang = language_code_dict[params.lang]
116
-
117
- if params.is_bgm_separate:
118
  music, audio, _ = self.music_separator.separate(
119
  audio=audio,
120
- model_name=params.uvr_model_size,
121
- device=params.uvr_device,
122
- segment_size=params.uvr_segment_size,
123
- save_file=params.uvr_save_file,
124
  progress=progress
125
  )
126
 
@@ -132,47 +123,54 @@ class WhisperBase(ABC):
132
  origin_sample_rate = self.music_separator.audio_info.sample_rate
133
  audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
134
 
135
- if params.uvr_enable_offload:
136
  self.music_separator.offload()
137
 
138
- if params.vad_filter:
139
- # Explicit value set for float('inf') from gr.Number()
140
- if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
141
- params.max_speech_duration_s = float('inf')
142
-
143
  vad_options = VadOptions(
144
- threshold=params.threshold,
145
- min_speech_duration_ms=params.min_speech_duration_ms,
146
- max_speech_duration_s=params.max_speech_duration_s,
147
- min_silence_duration_ms=params.min_silence_duration_ms,
148
- speech_pad_ms=params.speech_pad_ms
149
  )
150
 
151
- audio, speech_chunks = self.vad.run(
152
  audio=audio,
153
  vad_parameters=vad_options,
154
  progress=progress
155
  )
156
157
  result, elapsed_time = self.transcribe(
158
  audio,
159
  progress,
160
- *astuple(params)
161
  )
162
 
163
- if params.vad_filter:
164
  result = self.vad.restore_speech_timestamps(
165
  segments=result,
166
  speech_chunks=speech_chunks,
167
  )
168
 
169
- if params.is_diarize:
170
  result, elapsed_time_diarization = self.diarizer.run(
171
  audio=audio,
172
- use_auth_token=params.hf_token,
173
  transcribed_result=result,
 
174
  )
175
  elapsed_time += elapsed_time_diarization
 
  return result, elapsed_time
177
 
178
  def transcribe_file(self,
@@ -181,7 +179,7 @@ class WhisperBase(ABC):
181
  file_format: str = "SRT",
182
  add_timestamp: bool = True,
183
  progress=gr.Progress(),
184
- *whisper_params,
185
  ) -> list:
186
  """
187
  Write subtitle file from Files
@@ -199,8 +197,8 @@ class WhisperBase(ABC):
199
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
200
  progress: gr.Progress
201
  Indicator to show progress directly in gradio.
202
- *whisper_params: tuple
203
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
204
 
205
  Returns
206
  ----------
@@ -223,7 +221,7 @@ class WhisperBase(ABC):
223
  file,
224
  progress,
225
  add_timestamp,
226
- *whisper_params,
227
  )
228
 
229
  file_name, file_ext = os.path.splitext(os.path.basename(file))
@@ -471,7 +469,7 @@ class WhisperBase(ABC):
471
  if torch.cuda.is_available():
472
  return "cuda"
473
  elif torch.backends.mps.is_available():
474
- if not WhisperBase.is_sparse_api_supported():
475
  # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
476
  return "cpu"
477
  return "mps"
@@ -512,18 +510,60 @@ class WhisperBase(ABC):
512
  if file_path and os.path.exists(file_path):
513
  os.remove(file_path)
514
 
515
  @staticmethod
516
  def cache_parameters(
517
- whisper_params: WhisperValues,
518
  add_timestamp: bool
519
  ):
520
- """cache parameters to the yaml file"""
521
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
522
- cached_whisper_param = whisper_params.to_yaml()
523
- cached_yaml = {**cached_params, **cached_whisper_param}
 
524
  cached_yaml["whisper"]["add_timestamp"] = add_timestamp
525
 
526
- save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
527
 
528
  @staticmethod
529
  def resample_audio(audio: Union[str, np.ndarray],
 
1
  import os
2
  import torch
3
+ import ast
4
  import whisper
5
  import ctranslate2
6
  import gradio as gr
 
15
  from modules.uvr.music_separator import MusicSeparator
16
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
17
  UVR_MODELS_DIR)
18
+ from modules.utils.constants import *
19
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
20
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
21
  from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
22
+ from modules.whisper.data_classes import *
23
  from modules.diarize.diarizer import Diarizer
24
  from modules.vad.silero_vad import SileroVAD
25
 
26
 
27
+ class BaseTranscriptionPipeline(ABC):
28
  def __init__(self,
29
  model_dir: str = WHISPER_MODELS_DIR,
30
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
75
  audio: Union[str, BinaryIO, np.ndarray],
76
  progress: gr.Progress = gr.Progress(),
77
  add_timestamp: bool = True,
78
+ *pipeline_params,
79
  ) -> Tuple[List[dict], float]:
80
  """
81
  Run transcription with conditional pre-processing and post-processing.
82
  The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
83
  The diarization will be performed in post-processing, if enabled.
84
+ Due to the integration with gradio, the parameters have to be specified with a `*` wildcard.
85
 
86
  Parameters
87
  ----------
 
91
  Indicator to show progress directly in gradio.
92
  add_timestamp: bool
93
  Whether to add a timestamp at the end of the filename.
94
+ *pipeline_params: tuple
95
+ Parameters for the transcription pipeline. These are handled by the "TranscriptionPipelineParams" data class
96
 
97
  Returns
98
  ----------
 
101
  elapsed_time: float
102
  elapsed time for running
103
  """
104
+ params = TranscriptionPipelineParams.from_list(list(pipeline_params))
105
+ params = self.validate_gradio_values(params)
106
+ bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
107
 
108
+ if bgm_params.is_separate_bgm:
109
  music, audio, _ = self.music_separator.separate(
110
  audio=audio,
111
+ model_name=bgm_params.model_size,
112
+ device=bgm_params.device,
113
+ segment_size=bgm_params.segment_size,
114
+ save_file=bgm_params.save_file,
115
  progress=progress
116
  )
117
 
 
123
  origin_sample_rate = self.music_separator.audio_info.sample_rate
124
  audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
125
 
126
+ if bgm_params.enable_offload:
127
  self.music_separator.offload()
128
 
129
+ if vad_params.vad_filter:
130
  vad_options = VadOptions(
131
+ threshold=vad_params.threshold,
132
+ min_speech_duration_ms=vad_params.min_speech_duration_ms,
133
+ max_speech_duration_s=vad_params.max_speech_duration_s,
134
+ min_silence_duration_ms=vad_params.min_silence_duration_ms,
135
+ speech_pad_ms=vad_params.speech_pad_ms
136
  )
137
 
138
+ vad_processed, speech_chunks = self.vad.run(
139
  audio=audio,
140
  vad_parameters=vad_options,
141
  progress=progress
142
  )
143
 
144
+ if vad_processed.size > 0:
145
+ audio = vad_processed
146
+ else:
147
+ vad_params.vad_filter = False
148
+
149
  result, elapsed_time = self.transcribe(
150
  audio,
151
  progress,
152
+ *whisper_params.to_list()
153
  )
154
 
155
+ if vad_params.vad_filter:
156
  result = self.vad.restore_speech_timestamps(
157
  segments=result,
158
  speech_chunks=speech_chunks,
159
  )
160
 
161
+ if diarization_params.is_diarize:
162
  result, elapsed_time_diarization = self.diarizer.run(
163
  audio=audio,
164
+ use_auth_token=diarization_params.hf_token,
165
  transcribed_result=result,
166
+ device=diarization_params.device
167
  )
168
  elapsed_time += elapsed_time_diarization
169
+
170
+ self.cache_parameters(
171
+ params=params,
172
+ add_timestamp=add_timestamp
173
+ )
174
  return result, elapsed_time
175
 
176
  def transcribe_file(self,
 
179
  file_format: str = "SRT",
180
  add_timestamp: bool = True,
181
  progress=gr.Progress(),
182
+ *params,
183
  ) -> list:
184
  """
185
  Write subtitle file from Files
 
197
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
198
  progress: gr.Progress
199
  Indicator to show progress directly in gradio.
200
+ *params: tuple
201
+ Parameters for the transcription pipeline. These are handled by the "TranscriptionPipelineParams" data class
202
 
203
  Returns
204
  ----------
 
221
  file,
222
  progress,
223
  add_timestamp,
224
+ *params,
225
  )
226
 
227
  file_name, file_ext = os.path.splitext(os.path.basename(file))
 
469
  if torch.cuda.is_available():
470
  return "cuda"
471
  elif torch.backends.mps.is_available():
472
+ if not BaseTranscriptionPipeline.is_sparse_api_supported():
473
  # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
474
  return "cpu"
475
  return "mps"
 
510
  if file_path and os.path.exists(file_path):
511
  os.remove(file_path)
512
 
513
+ @staticmethod
514
+ def validate_gradio_values(params: TranscriptionPipelineParams):
515
+ """
516
+ Validate gradio specific values that can't be displayed as None in the UI.
517
+ Related issue : https://github.com/gradio-app/gradio/issues/8723
518
+ """
519
+ if params.whisper.lang is None:
520
+ pass
521
+ elif params.whisper.lang == AUTOMATIC_DETECTION:
522
+ params.whisper.lang = None
523
+ else:
524
+ language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
525
+ params.whisper.lang = language_code_dict[params.whisper.lang]
526
+
527
+ if params.whisper.initial_prompt == GRADIO_NONE_STR:
528
+ params.whisper.initial_prompt = None
529
+ if params.whisper.prefix == GRADIO_NONE_STR:
530
+ params.whisper.prefix = None
531
+ if params.whisper.hotwords == GRADIO_NONE_STR:
532
+ params.whisper.hotwords = None
533
+ if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN:
534
+ params.whisper.max_new_tokens = None
535
+ if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN:
536
+ params.whisper.hallucination_silence_threshold = None
537
+ if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN:
538
+ params.whisper.language_detection_threshold = None
539
+ if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
540
+ params.vad.max_speech_duration_s = float('inf')
541
+ return params
542
+
543
  @staticmethod
544
  def cache_parameters(
545
+ params: TranscriptionPipelineParams,
546
  add_timestamp: bool
547
  ):
548
+ """Cache parameters to the yaml file"""
549
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
550
+ param_to_cache = params.to_dict()
551
+
552
+ cached_yaml = {**cached_params, **param_to_cache}
553
  cached_yaml["whisper"]["add_timestamp"] = add_timestamp
554
 
555
+ suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
556
+ if suppress_token and isinstance(suppress_token, list):
557
+ cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)
558
+
559
+ if cached_yaml["whisper"].get("lang", None) is None:
560
+ cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
561
+
562
+ if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
563
+ cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
564
+
565
+ if cached_yaml is not None and cached_yaml:
566
+ save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
567
 
568
  @staticmethod
569
  def resample_audio(audio: Union[str, np.ndarray],
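
For orientation (illustrative sketch, not part of this diff): one detail of the new caching path is that suppress_tokens travels through the UI as a string such as "[-1]"; the field validator in data_classes.py parses it back into a list with ast.literal_eval, and cache_parameters() stringifies it again before the YAML is saved. A standalone sketch of that round trip; the `parse_suppress_tokens` helper is a hypothetical stand-in for the validator.

    import ast

    def parse_suppress_tokens(value):
        # Mirrors WhisperParams' suppress_tokens validator: accept "[-1]" or [-1]
        if isinstance(value, str):
            parsed = ast.literal_eval(value)
            if not isinstance(parsed, list):
                raise ValueError("Suppress Tokens must be a list of ints")
            return parsed
        return value

    tokens = parse_suppress_tokens("[-1]")
    print(tokens)       # [-1] -> what the transcription backend receives
    print(str(tokens))  # "[-1]" -> what cache_parameters() writes back to the YAML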
modules/whisper/data_classes.py ADDED
@@ -0,0 +1,565 @@
1
+ import gradio as gr
2
+ import torch
3
+ from typing import Optional, Dict, List, Union
4
+ from pydantic import BaseModel, Field, field_validator, ConfigDict
5
+ from gradio_i18n import Translate, gettext as _
6
+ from enum import Enum
7
+ from copy import deepcopy
8
+ import yaml
9
+
10
+ from modules.utils.constants import *
11
+
12
+
13
+ class WhisperImpl(Enum):
14
+ WHISPER = "whisper"
15
+ FASTER_WHISPER = "faster-whisper"
16
+ INSANELY_FAST_WHISPER = "insanely_fast_whisper"
17
+
18
+
19
+ class Segment(BaseModel):
20
+ text: Optional[str] = Field(default=None,
21
+ description="Transcription text of the segment")
22
+ start: Optional[float] = Field(default=None,
23
+ description="Start time of the segment")
24
+ end: Optional[float] = Field(default=None,
25
+ description="End time of the segment")
26
+
27
+
28
+ class BaseParams(BaseModel):
29
+ model_config = ConfigDict(protected_namespaces=())
30
+
31
+ def to_dict(self) -> Dict:
32
+ return self.model_dump()
33
+
34
+ def to_list(self) -> List:
35
+ return list(self.model_dump().values())
36
+
37
+ @classmethod
38
+ def from_list(cls, data_list: List) -> 'BaseParams':
39
+ field_names = list(cls.model_fields.keys())
40
+ return cls(**dict(zip(field_names, data_list)))
41
+
42
+
43
+ class VadParams(BaseParams):
44
+ """Voice Activity Detection parameters"""
45
+ vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
46
+ threshold: float = Field(
47
+ default=0.5,
48
+ ge=0.0,
49
+ le=1.0,
50
+ description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
51
+ )
52
+ min_speech_duration_ms: int = Field(
53
+ default=250,
54
+ ge=0,
55
+ description="Final speech chunks shorter than this are discarded"
56
+ )
57
+ max_speech_duration_s: float = Field(
58
+ default=float("inf"),
59
+ gt=0,
60
+ description="Maximum duration of speech chunks in seconds"
61
+ )
62
+ min_silence_duration_ms: int = Field(
63
+ default=2000,
64
+ ge=0,
65
+ description="Minimum silence duration between speech chunks"
66
+ )
67
+ speech_pad_ms: int = Field(
68
+ default=400,
69
+ ge=0,
70
+ description="Padding added to each side of speech chunks"
71
+ )
72
+
73
+ @classmethod
74
+ def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
75
+ return [
76
+ gr.Checkbox(
77
+ label=_("Enable Silero VAD Filter"),
78
+ value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default),
79
+ interactive=True,
80
+ info=_("Enable this to transcribe only detected voice")
81
+ ),
82
+ gr.Slider(
83
+ minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
84
+ value=defaults.get("threshold", cls.__fields__["threshold"].default),
85
+ info="Lower it to be more sensitive to small sounds."
86
+ ),
87
+ gr.Number(
88
+ label="Minimum Speech Duration (ms)", precision=0,
89
+ value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default),
90
+ info="Final speech chunks shorter than this time are thrown out"
91
+ ),
92
+ gr.Number(
93
+ label="Maximum Speech Duration (s)",
94
+ value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX),
95
+ info="Maximum duration of speech chunks in \"seconds\"."
96
+ ),
97
+ gr.Number(
98
+ label="Minimum Silence Duration (ms)", precision=0,
99
+ value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default),
100
+ info="In the end of each speech chunk wait for this time before separating it"
101
+ ),
102
+ gr.Number(
103
+ label="Speech Padding (ms)", precision=0,
104
+ value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default),
105
+ info="Final speech chunks are padded by this time each side"
106
+ )
107
+ ]
108
+
109
+
110
+ class DiarizationParams(BaseParams):
111
+ """Speaker diarization parameters"""
112
+ is_diarize: bool = Field(default=False, description="Enable speaker diarization")
113
+ device: str = Field(default="cuda", description="Device to run Diarization model.")
114
+ hf_token: str = Field(
115
+ default="",
116
+ description="Hugging Face token for downloading diarization models"
117
+ )
118
+
119
+ @classmethod
120
+ def to_gradio_inputs(cls,
121
+ defaults: Optional[Dict] = None,
122
+ available_devices: Optional[List] = None,
123
+ device: Optional[str] = None) -> List[gr.components.base.FormComponent]:
124
+ return [
125
+ gr.Checkbox(
126
+ label=_("Enable Diarization"),
127
+ value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default),
128
+ ),
129
+ gr.Dropdown(
130
+ label=_("Device"),
131
+ choices=["cpu", "cuda"] if available_devices is None else available_devices,
132
+ value=defaults.get("device", device),
133
+ ),
134
+ gr.Textbox(
135
+ label=_("HuggingFace Token"),
136
+ value=defaults.get("hf_token", cls.__fields__["hf_token"].default),
137
+ info=_("This is only needed the first time you download the model")
138
+ ),
139
+ ]
140
+
141
+
142
+ class BGMSeparationParams(BaseParams):
143
+ """Background music separation parameters"""
144
+ is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
145
+ model_size: str = Field(
146
+ default="UVR-MDX-NET-Inst_HQ_4",
147
+ description="UVR model size"
148
+ )
149
+ device: str = Field(default="cuda", description="Device to run UVR model.")
150
+ segment_size: int = Field(
151
+ default=256,
152
+ gt=0,
153
+ description="Segment size for UVR model"
154
+ )
155
+ save_file: bool = Field(
156
+ default=False,
157
+ description="Whether to save separated audio files"
158
+ )
159
+ enable_offload: bool = Field(
160
+ default=True,
161
+ description="Offload UVR model after transcription"
162
+ )
163
+
164
+ @classmethod
165
+ def to_gradio_input(cls,
166
+ defaults: Optional[Dict] = None,
167
+ available_devices: Optional[List] = None,
168
+ device: Optional[str] = None,
169
+ available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]:
170
+ return [
171
+ gr.Checkbox(
172
+ label=_("Enable Background Music Remover Filter"),
173
+ value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default),
174
+ interactive=True,
175
+ info=_("Enabling this will remove background music")
176
+ ),
177
+ gr.Dropdown(
178
+ label=_("Model"),
179
+ choices=["UVR-MDX-NET-Inst_HQ_4",
180
+ "UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
181
+ value=defaults.get("model_size", cls.__fields__["model_size"].default),
182
+ ),
183
+ gr.Dropdown(
184
+ label=_("Device"),
185
+ choices=["cpu", "cuda"] if available_devices is None else available_devices,
186
+ value=defaults.get("device", device),
187
+ ),
188
+ gr.Number(
189
+ label="Segment Size",
190
+ value=defaults.get("segment_size", cls.__fields__["segment_size"].default),
191
+ precision=0,
192
+ info="Segment size for UVR model"
193
+ ),
194
+ gr.Checkbox(
195
+ label=_("Save separated files to output"),
196
+ value=defaults.get("save_file", cls.__fields__["save_file"].default),
197
+ ),
198
+ gr.Checkbox(
199
+ label=_("Offload sub model after removing background music"),
200
+ value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default),
201
+ )
202
+ ]
203
+
204
+
205
+ class WhisperParams(BaseParams):
206
+ """Whisper parameters"""
207
+ model_size: str = Field(default="large-v2", description="Whisper model size")
208
+ lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
209
+ is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
210
+ beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
211
+ log_prob_threshold: float = Field(
212
+ default=-1.0,
213
+ description="Threshold for average log probability of sampled tokens"
214
+ )
215
+ no_speech_threshold: float = Field(
216
+ default=0.6,
217
+ ge=0.0,
218
+ le=1.0,
219
+ description="Threshold for detecting silence"
220
+ )
221
+ compute_type: str = Field(default="float16", description="Computation type for transcription")
222
+ best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
223
+ patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
224
+ condition_on_previous_text: bool = Field(
225
+ default=True,
226
+ description="Use previous output as prompt for next window"
227
+ )
228
+ prompt_reset_on_temperature: float = Field(
229
+ default=0.5,
230
+ ge=0.0,
231
+ le=1.0,
232
+ description="Temperature threshold for resetting prompt"
233
+ )
234
+ initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
235
+ temperature: float = Field(
236
+ default=0.0,
237
+ ge=0.0,
238
+ description="Temperature for sampling"
239
+ )
240
+ compression_ratio_threshold: float = Field(
241
+ default=2.4,
242
+ gt=0,
243
+ description="Threshold for gzip compression ratio"
244
+ )
245
+ length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
246
+ repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
247
+ no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
248
+ prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
249
+ suppress_blank: bool = Field(
250
+ default=True,
251
+ description="Suppress blank outputs at start of sampling"
252
+ )
253
+ suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress")
254
+ max_initial_timestamp: float = Field(
255
+ default=0.0,
256
+ ge=0.0,
257
+ description="Maximum initial timestamp"
258
+ )
259
+ word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
260
+ prepend_punctuations: Optional[str] = Field(
261
+ default="\"'“¿([{-",
262
+ description="Punctuations to merge with next word"
263
+ )
264
+ append_punctuations: Optional[str] = Field(
265
+ default="\"'.。,,!!??::”)]}、",
266
+ description="Punctuations to merge with previous word"
267
+ )
268
+ max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
269
+ chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
270
+ hallucination_silence_threshold: Optional[float] = Field(
271
+ default=None,
272
+ description="Threshold for skipping silent periods in hallucination detection"
273
+ )
274
+ hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
275
+ language_detection_threshold: Optional[float] = Field(
276
+ default=None,
277
+ description="Threshold for language detection probability"
278
+ )
279
+ language_detection_segments: int = Field(
280
+ default=1,
281
+ gt=0,
282
+ description="Number of segments for language detection"
283
+ )
284
+ batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
285
+
286
+ @field_validator('lang')
287
+ def validate_lang(cls, v):
288
+ from modules.utils.constants import AUTOMATIC_DETECTION
289
+ return None if v == AUTOMATIC_DETECTION.unwrap() else v
290
+
291
+ @field_validator('suppress_tokens')
292
+ def validate_suppress_tokens(cls, v):
293
+ import ast
294
+ try:
295
+ if isinstance(v, str):
296
+ suppress_tokens = ast.literal_eval(v)
297
+ if not isinstance(suppress_tokens, list):
298
+ raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
299
+ return suppress_tokens
300
+ if isinstance(v, list):
301
+ return v
302
+ except Exception as e:
303
+ raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")
304
+
305
+ @classmethod
306
+ def to_gradio_inputs(cls,
307
+ defaults: Optional[Dict] = None,
308
+ only_advanced: Optional[bool] = True,
309
+ whisper_type: Optional[str] = None,
310
+ available_models: Optional[List] = None,
311
+ available_langs: Optional[List] = None,
312
+ available_compute_types: Optional[List] = None,
313
+ compute_type: Optional[str] = None):
314
+ whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower()
315
+
316
+ inputs = []
317
+ if not only_advanced:
318
+ inputs += [
319
+ gr.Dropdown(
320
+ label=_("Model"),
321
+ choices=available_models,
322
+ value=defaults.get("model_size", cls.__fields__["model_size"].default),
323
+ ),
324
+ gr.Dropdown(
325
+ label=_("Language"),
326
+ choices=available_langs,
327
+ value=defaults.get("lang", AUTOMATIC_DETECTION),
328
+ ),
329
+ gr.Checkbox(
330
+ label=_("Translate to English?"),
331
+ value=defaults.get("is_translate", cls.__fields__["is_translate"].default),
332
+ ),
333
+ ]
334
+
335
+ inputs += [
336
+ gr.Number(
337
+ label="Beam Size",
338
+ value=defaults.get("beam_size", cls.__fields__["beam_size"].default),
339
+ precision=0,
340
+ info="Beam size for decoding"
341
+ ),
342
+ gr.Number(
343
+ label="Log Probability Threshold",
344
+ value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default),
345
+ info="Threshold for average log probability of sampled tokens"
346
+ ),
347
+ gr.Number(
348
+ label="No Speech Threshold",
349
+ value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default),
350
+ info="Threshold for detecting silence"
351
+ ),
352
+ gr.Dropdown(
353
+ label="Compute Type",
354
+ choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types,
355
+ value=defaults.get("compute_type", compute_type),
356
+ info="Computation type for transcription"
357
+ ),
358
+ gr.Number(
359
+ label="Best Of",
360
+ value=defaults.get("best_of", cls.__fields__["best_of"].default),
361
+ precision=0,
362
+ info="Number of candidates when sampling"
363
+ ),
364
+ gr.Number(
365
+ label="Patience",
366
+ value=defaults.get("patience", cls.__fields__["patience"].default),
367
+ info="Beam search patience factor"
368
+ ),
369
+ gr.Checkbox(
370
+ label="Condition On Previous Text",
371
+ value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default),
372
+ info="Use previous output as prompt for next window"
373
+ ),
374
+ gr.Slider(
375
+ label="Prompt Reset On Temperature",
376
+ value=defaults.get("prompt_reset_on_temperature",
377
+ cls.__fields__["prompt_reset_on_temperature"].default),
378
+ minimum=0,
379
+ maximum=1,
380
+ step=0.01,
381
+ info="Temperature threshold for resetting prompt"
382
+ ),
383
+ gr.Textbox(
384
+ label="Initial Prompt",
385
+ value=defaults.get("initial_prompt", GRADIO_NONE_STR),
386
+ info="Initial prompt for first window"
387
+ ),
388
+ gr.Slider(
389
+ label="Temperature",
390
+ value=defaults.get("temperature", cls.__fields__["temperature"].default),
391
+ minimum=0.0,
392
+ step=0.01,
393
+ maximum=1.0,
394
+ info="Temperature for sampling"
395
+ ),
396
+ gr.Number(
397
+ label="Compression Ratio Threshold",
398
+ value=defaults.get("compression_ratio_threshold",
399
+ cls.__fields__["compression_ratio_threshold"].default),
400
+ info="Threshold for gzip compression ratio"
401
+ )
402
+ ]
403
+
404
+ faster_whisper_inputs = [
405
+ gr.Number(
406
+ label="Length Penalty",
407
+ value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
408
+ info="Exponential length penalty",
409
+ ),
410
+ gr.Number(
411
+ label="Repetition Penalty",
412
+ value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
413
+ info="Penalty for repeated tokens"
414
+ ),
415
+ gr.Number(
416
+ label="No Repeat N-gram Size",
417
+ value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
418
+ precision=0,
419
+ info="Size of n-grams to prevent repetition"
420
+ ),
421
+ gr.Textbox(
422
+ label="Prefix",
423
+ value=defaults.get("prefix", GRADIO_NONE_STR),
424
+ info="Prefix text for first window"
425
+ ),
426
+ gr.Checkbox(
427
+ label="Suppress Blank",
428
+ value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
429
+ info="Suppress blank outputs at start of sampling"
430
+ ),
431
+ gr.Textbox(
432
+ label="Suppress Tokens",
433
+ value=defaults.get("suppress_tokens", "[-1]"),
434
+ info="Token IDs to suppress"
435
+ ),
436
+ gr.Number(
437
+ label="Max Initial Timestamp",
438
+ value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
439
+ info="Maximum initial timestamp"
440
+ ),
441
+ gr.Checkbox(
442
+ label="Word Timestamps",
443
+ value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
444
+ info="Extract word-level timestamps"
445
+ ),
446
+ gr.Textbox(
447
+ label="Prepend Punctuations",
448
+ value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
449
+ info="Punctuations to merge with next word"
450
+ ),
451
+ gr.Textbox(
452
+ label="Append Punctuations",
453
+ value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
454
+ info="Punctuations to merge with previous word"
455
+ ),
456
+ gr.Number(
457
+ label="Max New Tokens",
458
+ value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN),
459
+ precision=0,
460
+ info="Maximum number of new tokens per chunk"
461
+ ),
462
+ gr.Number(
463
+ label="Chunk Length (s)",
464
+ value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
465
+ precision=0,
466
+ info="Length of audio segments in seconds"
467
+ ),
468
+ gr.Number(
469
+ label="Hallucination Silence Threshold (sec)",
470
+ value=defaults.get("hallucination_silence_threshold",
471
+ GRADIO_NONE_NUMBER_MIN),
472
+ info="Threshold for skipping silent periods in hallucination detection"
473
+ ),
474
+ gr.Textbox(
475
+ label="Hotwords",
476
+ value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
477
+ info="Hotwords/hint phrases for the model"
478
+ ),
479
+ gr.Number(
480
+ label="Language Detection Threshold",
481
+ value=defaults.get("language_detection_threshold",
482
+ GRADIO_NONE_NUMBER_MIN),
483
+ info="Threshold for language detection probability"
484
+ ),
485
+ gr.Number(
486
+ label="Language Detection Segments",
487
+ value=defaults.get("language_detection_segments",
488
+ cls.__fields__["language_detection_segments"].default),
489
+ precision=0,
490
+ info="Number of segments for language detection"
491
+ )
492
+ ]
493
+
494
+ insanely_fast_whisper_inputs = [
495
+ gr.Number(
496
+ label="Batch Size",
497
+ value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
498
+ precision=0,
499
+ info="Batch size for processing"
500
+ )
501
+ ]
502
+
503
+ if whisper_type != WhisperImpl.FASTER_WHISPER.value:
504
+ for input_component in faster_whisper_inputs:
505
+ input_component.visible = False
506
+
507
+ if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value:
508
+ for input_component in insanely_fast_whisper_inputs:
509
+ input_component.visible = False
510
+
511
+ inputs += faster_whisper_inputs + insanely_fast_whisper_inputs
512
+
513
+ return inputs
514
+
515
+
516
+ class TranscriptionPipelineParams(BaseModel):
517
+ """Transcription pipeline parameters"""
518
+ whisper: WhisperParams = Field(default_factory=WhisperParams)
519
+ vad: VadParams = Field(default_factory=VadParams)
520
+ diarization: DiarizationParams = Field(default_factory=DiarizationParams)
521
+ bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams)
522
+
523
+ def to_dict(self) -> Dict:
524
+ data = {
525
+ "whisper": self.whisper.to_dict(),
526
+ "vad": self.vad.to_dict(),
527
+ "diarization": self.diarization.to_dict(),
528
+ "bgm_separation": self.bgm_separation.to_dict()
529
+ }
530
+ return data
531
+
532
+ def to_list(self) -> List:
533
+ """
534
+ Convert the data class to a list, because the parameters have to be passed to Gradio as a list.
535
+ Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
536
+ See more about Gradio pre-processing: https://www.gradio.app/docs/components
537
+ """
538
+ whisper_list = self.whisper.to_list()
539
+ vad_list = self.vad.to_list()
540
+ diarization_list = self.diarization.to_list()
541
+ bgm_sep_list = self.bgm_separation.to_list()
542
+ return whisper_list + vad_list + diarization_list + bgm_sep_list
543
+
544
+ @staticmethod
545
+ def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
546
+ """Convert the list back into the data class so it can be used in a function."""
547
+ data_list = deepcopy(pipeline_list)
548
+
549
+ whisper_list = data_list[0:len(WhisperParams.__annotations__)]
550
+ data_list = data_list[len(WhisperParams.__annotations__):]
551
+
552
+ vad_list = data_list[0:len(VadParams.__annotations__)]
553
+ data_list = data_list[len(VadParams.__annotations__):]
554
+
555
+ diarization_list = data_list[0:len(DiarizationParams.__annotations__)]
556
+ data_list = data_list[len(DiarizationParams.__annotations__):]
557
+
558
+ bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)]
559
+
560
+ return TranscriptionPipelineParams(
561
+ whisper=WhisperParams.from_list(whisper_list),
562
+ vad=VadParams.from_list(vad_list),
563
+ diarization=DiarizationParams.from_list(diarization_list),
564
+ bgm_separation=BGMSeparationParams.from_list(bgm_sep_list)
565
+ )
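
A minimal usage sketch of the new data classes (not part of the commit): it assumes the modules/whisper/data_classes.py layout shown in this hunk, and that WhisperParams/VadParams expose the same to_list() flattening in field-declaration order that TranscriptionPipelineParams.from_list() relies on.

from modules.whisper.data_classes import WhisperParams, VadParams, TranscriptionPipelineParams

# Field validators run on construction, e.g. a string such as "[-1]" is parsed
# into a token-ID list by the suppress_tokens validator above.
whisper = WhisperParams(model_size="large-v2", suppress_tokens="[-1]")
assert whisper.suppress_tokens == [-1]

# Gradio hands component values over as a flat list, so the pipeline params
# round-trip through to_list() / from_list() instead of being passed as objects.
pipeline = TranscriptionPipelineParams(whisper=whisper, vad=VadParams(vad_filter=True))
flat = pipeline.to_list()
restored = TranscriptionPipelineParams.from_list(flat)
assert restored.vad.vad_filter is True  # holds if to_list() flattens in declaration order
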
modules/whisper/faster_whisper_inference.py CHANGED
@@ -12,11 +12,11 @@ import gradio as gr
12
  from argparse import Namespace
13
 
14
  from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
- from modules.whisper.whisper_parameter import *
16
- from modules.whisper.whisper_base import WhisperBase
17
 
18
 
19
- class FasterWhisperInference(WhisperBase):
20
  def __init__(self,
21
  model_dir: str = FASTER_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
40
  audio: Union[str, BinaryIO, np.ndarray],
41
  progress: gr.Progress = gr.Progress(),
42
  *whisper_params,
43
- ) -> Tuple[List[dict], float]:
44
  """
45
  transcribe method for faster-whisper.
46
 
@@ -55,28 +55,18 @@ class FasterWhisperInference(WhisperBase):
55
 
56
  Returns
57
  ----------
58
- segments_result: List[dict]
59
- list of dicts that includes start, end timestamps and transcribed text
60
  elapsed_time: float
61
  elapsed time for transcription
62
  """
63
  start_time = time.time()
64
 
65
- params = WhisperParameters.as_value(*whisper_params)
66
 
67
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
68
  self.update_model(params.model_size, params.compute_type, progress)
69
 
70
- # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
71
- if not params.initial_prompt:
72
- params.initial_prompt = None
73
- if not params.prefix:
74
- params.prefix = None
75
- if not params.hotwords:
76
- params.hotwords = None
77
-
78
- params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
79
-
80
  segments, info = self.model.transcribe(
81
  audio=audio,
82
  language=params.lang,
@@ -112,11 +102,11 @@ class FasterWhisperInference(WhisperBase):
112
  segments_result = []
113
  for segment in segments:
114
  progress(segment.start / info.duration, desc="Transcribing..")
115
- segments_result.append({
116
- "start": segment.start,
117
- "end": segment.end,
118
- "text": segment.text
119
- })
120
 
121
  elapsed_time = time.time() - start_time
122
  return segments_result, elapsed_time
 
12
  from argparse import Namespace
13
 
14
  from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
+ from modules.whisper.data_classes import *
16
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
17
 
18
 
19
+ class FasterWhisperInference(BaseTranscriptionPipeline):
20
  def __init__(self,
21
  model_dir: str = FASTER_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
40
  audio: Union[str, BinaryIO, np.ndarray],
41
  progress: gr.Progress = gr.Progress(),
42
  *whisper_params,
43
+ ) -> Tuple[List[Segment], float]:
44
  """
45
  transcribe method for faster-whisper.
46
 
 
55
 
56
  Returns
57
  ----------
58
+ segments_result: List[Segment]
59
+ list of Segment that includes start, end timestamps and transcribed text
60
  elapsed_time: float
61
  elapsed time for transcription
62
  """
63
  start_time = time.time()
64
 
65
+ params = WhisperParams.from_list(list(whisper_params))
66
 
67
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
68
  self.update_model(params.model_size, params.compute_type, progress)
69
 
 
 
 
 
 
 
 
 
 
 
70
  segments, info = self.model.transcribe(
71
  audio=audio,
72
  language=params.lang,
 
102
  segments_result = []
103
  for segment in segments:
104
  progress(segment.start / info.duration, desc="Transcribing..")
105
+ segments_result.append(Segment(
106
+ start=segment.start,
107
+ end=segment.end,
108
+ text=segment.text
109
+ ))
110
 
111
  elapsed_time = time.time() - start_time
112
  return segments_result, elapsed_time
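
A minimal driving sketch for the refactored transcribe() (not from the commit): the audio path is hypothetical, FasterWhisperInference() is assumed to work with its default directories, and WhisperParams.to_list() is assumed to flatten fields in declaration order; the (audio, progress, *whisper_params) signature and the List[Segment] return value are taken from the diff.

import gradio as gr
from modules.whisper.data_classes import WhisperParams
from modules.whisper.faster_whisper_inference import FasterWhisperInference

pipe = FasterWhisperInference()             # default model/output dirs from modules.utils.paths
params = WhisperParams(model_size="tiny", beam_size=5)
# transcribe() unpacks the flat list back into WhisperParams via from_list().
segments, elapsed = pipe.transcribe("tests/jfk.wav", gr.Progress(), *params.to_list())
for seg in segments:                        # List[Segment] instead of List[dict]
    print(f"[{seg.start:6.2f} -> {seg.end:6.2f}]{seg.text}")
print(f"transcribed in {elapsed:.1f}s")
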
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -12,11 +12,11 @@ from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
12
  from argparse import Namespace
13
 
14
  from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
- from modules.whisper.whisper_parameter import *
16
- from modules.whisper.whisper_base import WhisperBase
17
 
18
 
19
- class InsanelyFastWhisperInference(WhisperBase):
20
  def __init__(self,
21
  model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -40,7 +40,7 @@ class InsanelyFastWhisperInference(WhisperBase):
40
  audio: Union[str, np.ndarray, torch.Tensor],
41
  progress: gr.Progress = gr.Progress(),
42
  *whisper_params,
43
- ) -> Tuple[List[dict], float]:
44
  """
45
  transcribe method for insanely-fast-whisper.
46
 
@@ -55,13 +55,13 @@ class InsanelyFastWhisperInference(WhisperBase):
55
 
56
  Returns
57
  ----------
58
- segments_result: List[dict]
59
- list of dicts that includes start, end timestamps and transcribed text
60
  elapsed_time: float
61
  elapsed time for transcription
62
  """
63
  start_time = time.time()
64
- params = WhisperParameters.as_value(*whisper_params)
65
 
66
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
67
  self.update_model(params.model_size, params.compute_type, progress)
@@ -95,9 +95,17 @@ class InsanelyFastWhisperInference(WhisperBase):
95
  generate_kwargs=kwargs
96
  )
97
 
98
- segments_result = self.format_result(
99
- transcribed_result=segments,
100
- )
 
 
 
 
 
 
 
 
101
  elapsed_time = time.time() - start_time
102
  return segments_result, elapsed_time
103
 
 
12
  from argparse import Namespace
13
 
14
  from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
+ from modules.whisper.data_classes import *
16
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
17
 
18
 
19
+ class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
20
  def __init__(self,
21
  model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
40
  audio: Union[str, np.ndarray, torch.Tensor],
41
  progress: gr.Progress = gr.Progress(),
42
  *whisper_params,
43
+ ) -> Tuple[List[Segment], float]:
44
  """
45
  transcribe method for insanely-fast-whisper.
46
 
 
55
 
56
  Returns
57
  ----------
58
+ segments_result: List[Segment]
59
+ list of Segment that includes start, end timestamps and transcribed text
60
  elapsed_time: float
61
  elapsed time for transcription
62
  """
63
  start_time = time.time()
64
+ params = WhisperParams.from_list(list(whisper_params))
65
 
66
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
67
  self.update_model(params.model_size, params.compute_type, progress)
 
95
  generate_kwargs=kwargs
96
  )
97
 
98
+ segments_result = []
99
+ for item in segments["chunks"]:
100
+ start, end = item["timestamp"][0], item["timestamp"][1]
101
+ if end is None:
102
+ end = start
103
+ segments_result.append(Segment(
104
+ text=item["text"],
105
+ start=start,
106
+ end=end
107
+ ))
108
+
109
  elapsed_time = time.time() - start_time
110
  return segments_result, elapsed_time
111
 
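To make the new chunk handling above concrete, a standalone sketch with made-up pipeline output (only the mapping mirrors the code; the sample dict is not real data): chunks whose end timestamp is None are clamped to their start before building a Segment.

from modules.whisper.data_classes import Segment

result = {"chunks": [
    {"timestamp": (0.0, 2.1), "text": " And so my fellow Americans"},
    {"timestamp": (2.1, None), "text": " ask not"},   # pipeline left the end open
]}
segments = []
for item in result["chunks"]:
    start, end = item["timestamp"]
    segments.append(Segment(start=start, end=end if end is not None else start, text=item["text"]))
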
modules/whisper/whisper_Inference.py CHANGED
@@ -8,11 +8,11 @@ import os
8
  from argparse import Namespace
9
 
10
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
11
- from modules.whisper.whisper_base import WhisperBase
12
- from modules.whisper.whisper_parameter import *
13
 
14
 
15
- class WhisperInference(WhisperBase):
16
  def __init__(self,
17
  model_dir: str = WHISPER_MODELS_DIR,
18
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -30,7 +30,7 @@ class WhisperInference(WhisperBase):
30
  audio: Union[str, np.ndarray, torch.Tensor],
31
  progress: gr.Progress = gr.Progress(),
32
  *whisper_params,
33
- ) -> Tuple[List[dict], float]:
34
  """
35
  transcribe method for faster-whisper.
36
 
@@ -45,13 +45,13 @@ class WhisperInference(WhisperBase):
45
 
46
  Returns
47
  ----------
48
- segments_result: List[dict]
49
- list of dicts that includes start, end timestamps and transcribed text
50
  elapsed_time: float
51
  elapsed time for transcription
52
  """
53
  start_time = time.time()
54
- params = WhisperParameters.as_value(*whisper_params)
55
 
56
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
57
  self.update_model(params.model_size, params.compute_type, progress)
@@ -59,21 +59,28 @@ class WhisperInference(WhisperBase):
59
  def progress_callback(progress_value):
60
  progress(progress_value, desc="Transcribing..")
61
 
62
- segments_result = self.model.transcribe(audio=audio,
63
- language=params.lang,
64
- verbose=False,
65
- beam_size=params.beam_size,
66
- logprob_threshold=params.log_prob_threshold,
67
- no_speech_threshold=params.no_speech_threshold,
68
- task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
69
- fp16=True if params.compute_type == "float16" else False,
70
- best_of=params.best_of,
71
- patience=params.patience,
72
- temperature=params.temperature,
73
- compression_ratio_threshold=params.compression_ratio_threshold,
74
- progress_callback=progress_callback,)["segments"]
75
- elapsed_time = time.time() - start_time
 
 
 
 
 
 
76
 
 
77
  return segments_result, elapsed_time
78
 
79
  def update_model(self,
 
8
  from argparse import Namespace
9
 
10
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
11
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
12
+ from modules.whisper.data_classes import *
13
 
14
 
15
+ class WhisperInference(BaseTranscriptionPipeline):
16
  def __init__(self,
17
  model_dir: str = WHISPER_MODELS_DIR,
18
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
30
  audio: Union[str, np.ndarray, torch.Tensor],
31
  progress: gr.Progress = gr.Progress(),
32
  *whisper_params,
33
+ ) -> Tuple[List[Segment], float]:
34
  """
35
  transcribe method for whisper.
36
 
 
45
 
46
  Returns
47
  ----------
48
+ segments_result: List[Segment]
49
+ list of Segment that includes start, end timestamps and transcribed text
50
  elapsed_time: float
51
  elapsed time for transcription
52
  """
53
  start_time = time.time()
54
+ params = WhisperParams.from_list(list(whisper_params))
55
 
56
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
57
  self.update_model(params.model_size, params.compute_type, progress)
 
59
  def progress_callback(progress_value):
60
  progress(progress_value, desc="Transcribing..")
61
 
62
+ result = self.model.transcribe(audio=audio,
63
+ language=params.lang,
64
+ verbose=False,
65
+ beam_size=params.beam_size,
66
+ logprob_threshold=params.log_prob_threshold,
67
+ no_speech_threshold=params.no_speech_threshold,
68
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
69
+ fp16=True if params.compute_type == "float16" else False,
70
+ best_of=params.best_of,
71
+ patience=params.patience,
72
+ temperature=params.temperature,
73
+ compression_ratio_threshold=params.compression_ratio_threshold,
74
+ progress_callback=progress_callback,)["segments"]
75
+ segments_result = []
76
+ for segment in result:
77
+ segments_result.append(Segment(
78
+ start=segment["start"],
79
+ end=segment["end"],
80
+ text=segment["text"]
81
+ ))
82
 
83
+ elapsed_time = time.time() - start_time
84
  return segments_result, elapsed_time
85
 
86
  def update_model(self,
modules/whisper/whisper_factory.py CHANGED
@@ -6,7 +6,8 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
6
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
7
  from modules.whisper.whisper_Inference import WhisperInference
8
  from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
9
- from modules.whisper.whisper_base import WhisperBase
 
10
 
11
 
12
  class WhisperFactory:
@@ -19,7 +20,7 @@ class WhisperFactory:
19
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
20
  uvr_model_dir: str = UVR_MODELS_DIR,
21
  output_dir: str = OUTPUT_DIR,
22
- ) -> "WhisperBase":
23
  """
24
  Create a whisper inference class based on the provided whisper_type.
25
 
@@ -45,36 +46,29 @@ class WhisperFactory:
45
 
46
  Returns
47
  -------
48
- WhisperBase
49
  An instance of the appropriate whisper inference class based on the whisper_type.
50
  """
51
  # Temporary fix for the bug: https://github.com/jhj0517/Whisper-WebUI/issues/144
52
  os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
53
 
54
- whisper_type = whisper_type.lower().strip()
55
 
56
- faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
57
- whisper_typos = ["whisper"]
58
- insanely_fast_whisper_typos = [
59
- "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
60
- "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
61
- ]
62
-
63
- if whisper_type in faster_whisper_typos:
64
  return FasterWhisperInference(
65
  model_dir=faster_whisper_model_dir,
66
  output_dir=output_dir,
67
  diarization_model_dir=diarization_model_dir,
68
  uvr_model_dir=uvr_model_dir
69
  )
70
- elif whisper_type in whisper_typos:
71
  return WhisperInference(
72
  model_dir=whisper_model_dir,
73
  output_dir=output_dir,
74
  diarization_model_dir=diarization_model_dir,
75
  uvr_model_dir=uvr_model_dir
76
  )
77
- elif whisper_type in insanely_fast_whisper_typos:
78
  return InsanelyFastWhisperInference(
79
  model_dir=insanely_fast_whisper_model_dir,
80
  output_dir=output_dir,
 
6
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
7
  from modules.whisper.whisper_Inference import WhisperInference
8
  from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
9
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
10
+ from modules.whisper.data_classes import *
11
 
12
 
13
  class WhisperFactory:
 
20
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
21
  uvr_model_dir: str = UVR_MODELS_DIR,
22
  output_dir: str = OUTPUT_DIR,
23
+ ) -> "BaseTranscriptionPipeline":
24
  """
25
  Create a whisper inference class based on the provided whisper_type.
26
 
 
46
 
47
  Returns
48
  -------
49
+ BaseTranscriptionPipeline
50
  An instance of the appropriate whisper inference class based on the whisper_type.
51
  """
52
  # Temporary fix for the bug: https://github.com/jhj0517/Whisper-WebUI/issues/144
53
  os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
54
 
55
+ whisper_type = whisper_type.strip().lower()
56
 
57
+ if whisper_type == WhisperImpl.FASTER_WHISPER.value:
 
 
 
 
 
 
 
58
  return FasterWhisperInference(
59
  model_dir=faster_whisper_model_dir,
60
  output_dir=output_dir,
61
  diarization_model_dir=diarization_model_dir,
62
  uvr_model_dir=uvr_model_dir
63
  )
64
+ elif whisper_type == WhisperImpl.WHISPER.value:
65
  return WhisperInference(
66
  model_dir=whisper_model_dir,
67
  output_dir=output_dir,
68
  diarization_model_dir=diarization_model_dir,
69
  uvr_model_dir=uvr_model_dir
70
  )
71
+ elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value:
72
  return InsanelyFastWhisperInference(
73
  model_dir=insanely_fast_whisper_model_dir,
74
  output_dir=output_dir,
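
Since the factory now dispatches on the WhisperImpl enum instead of the old typo-alias lists, callers are expected to pass the enum value (it is normalized with .strip().lower() before matching). A hedged sketch follows; the factory method name sits above the hunks shown here, so create_whisper_inference is an assumption rather than something visible in this diff.

from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import WhisperImpl

pipeline = WhisperFactory.create_whisper_inference(   # method name assumed, not shown in this hunk
    whisper_type=WhisperImpl.FASTER_WHISPER.value,    # exact match after .strip().lower()
)
print(type(pipeline).__name__)                        # -> FasterWhisperInference
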
modules/whisper/whisper_parameter.py DELETED
@@ -1,371 +0,0 @@
1
- from dataclasses import dataclass, fields
2
- import gradio as gr
3
- from typing import Optional, Dict
4
- import yaml
5
-
6
- from modules.utils.constants import AUTOMATIC_DETECTION
7
-
8
-
9
- @dataclass
10
- class WhisperParameters:
11
- model_size: gr.Dropdown
12
- lang: gr.Dropdown
13
- is_translate: gr.Checkbox
14
- beam_size: gr.Number
15
- log_prob_threshold: gr.Number
16
- no_speech_threshold: gr.Number
17
- compute_type: gr.Dropdown
18
- best_of: gr.Number
19
- patience: gr.Number
20
- condition_on_previous_text: gr.Checkbox
21
- prompt_reset_on_temperature: gr.Slider
22
- initial_prompt: gr.Textbox
23
- temperature: gr.Slider
24
- compression_ratio_threshold: gr.Number
25
- vad_filter: gr.Checkbox
26
- threshold: gr.Slider
27
- min_speech_duration_ms: gr.Number
28
- max_speech_duration_s: gr.Number
29
- min_silence_duration_ms: gr.Number
30
- speech_pad_ms: gr.Number
31
- batch_size: gr.Number
32
- is_diarize: gr.Checkbox
33
- hf_token: gr.Textbox
34
- diarization_device: gr.Dropdown
35
- length_penalty: gr.Number
36
- repetition_penalty: gr.Number
37
- no_repeat_ngram_size: gr.Number
38
- prefix: gr.Textbox
39
- suppress_blank: gr.Checkbox
40
- suppress_tokens: gr.Textbox
41
- max_initial_timestamp: gr.Number
42
- word_timestamps: gr.Checkbox
43
- prepend_punctuations: gr.Textbox
44
- append_punctuations: gr.Textbox
45
- max_new_tokens: gr.Number
46
- chunk_length: gr.Number
47
- hallucination_silence_threshold: gr.Number
48
- hotwords: gr.Textbox
49
- language_detection_threshold: gr.Number
50
- language_detection_segments: gr.Number
51
- is_bgm_separate: gr.Checkbox
52
- uvr_model_size: gr.Dropdown
53
- uvr_device: gr.Dropdown
54
- uvr_segment_size: gr.Number
55
- uvr_save_file: gr.Checkbox
56
- uvr_enable_offload: gr.Checkbox
57
- """
58
- A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
59
- This data class is used to mitigate the key-value problem between Gradio components and function parameters.
60
- Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
61
- See more about Gradio pre-processing: https://www.gradio.app/docs/components
62
-
63
- Attributes
64
- ----------
65
- model_size: gr.Dropdown
66
- Whisper model size.
67
-
68
- lang: gr.Dropdown
69
- Source language of the file to transcribe.
70
-
71
- is_translate: gr.Checkbox
72
- Boolean value that determines whether to translate to English.
73
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
74
-
75
- beam_size: gr.Number
76
- Int value that is used for decoding option.
77
-
78
- log_prob_threshold: gr.Number
79
- If the average log probability over sampled tokens is below this value, treat as failed.
80
-
81
- no_speech_threshold: gr.Number
82
- If the no_speech probability is higher than this value AND
83
- the average log probability over sampled tokens is below `log_prob_threshold`,
84
- consider the segment as silent.
85
-
86
- compute_type: gr.Dropdown
87
- compute type for transcription.
88
- see more info : https://opennmt.net/CTranslate2/quantization.html
89
-
90
- best_of: gr.Number
91
- Number of candidates when sampling with non-zero temperature.
92
-
93
- patience: gr.Number
94
- Beam search patience factor.
95
-
96
- condition_on_previous_text: gr.Checkbox
97
- if True, the previous output of the model is provided as a prompt for the next window;
98
- disabling may make the text inconsistent across windows, but the model becomes less prone to
99
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
100
-
101
- initial_prompt: gr.Textbox
102
- Optional text to provide as a prompt for the first window. This can be used to provide, or
103
- "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
104
- to make it more likely to predict those word correctly.
105
-
106
- temperature: gr.Slider
107
- Temperature for sampling. It can be a tuple of temperatures,
108
- which will be successively used upon failures according to either
109
- `compression_ratio_threshold` or `log_prob_threshold`.
110
-
111
- compression_ratio_threshold: gr.Number
112
- If the gzip compression ratio is above this value, treat as failed
113
-
114
- vad_filter: gr.Checkbox
115
- Enable the voice activity detection (VAD) to filter out parts of the audio
116
- without speech. This step is using the Silero VAD model
117
- https://github.com/snakers4/silero-vad.
118
-
119
- threshold: gr.Slider
120
- This parameter is related with Silero VAD. Speech threshold.
121
- Silero VAD outputs speech probabilities for each audio chunk,
122
- probabilities ABOVE this value are considered as SPEECH. It is better to tune this
123
- parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
124
-
125
- min_speech_duration_ms: gr.Number
126
- This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.
127
-
128
- max_speech_duration_s: gr.Number
129
- This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
130
- than max_speech_duration_s will be split at the timestamp of the last silence that
131
- lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
132
- split aggressively just before max_speech_duration_s.
133
-
134
- min_silence_duration_ms: gr.Number
135
- This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
136
- before separating it
137
-
138
- speech_pad_ms: gr.Number
139
- This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
140
-
141
- batch_size: gr.Number
142
- This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
143
-
144
- is_diarize: gr.Checkbox
145
- This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
146
-
147
- hf_token: gr.Textbox
148
- This parameter is related with whisperx. Huggingface token is needed to download diarization models.
149
- Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
150
-
151
- diarization_device: gr.Dropdown
152
- This parameter is related with whisperx. Device to run diarization model
153
-
154
- length_penalty: gr.Number
155
- This parameter is related to faster-whisper. Exponential length penalty constant.
156
-
157
- repetition_penalty: gr.Number
158
- This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
159
- (set > 1 to penalize).
160
-
161
- no_repeat_ngram_size: gr.Number
162
- This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
163
-
164
- prefix: gr.Textbox
165
- This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
166
-
167
- suppress_blank: gr.Checkbox
168
- This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
169
-
170
- suppress_tokens: gr.Textbox
171
- This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
172
- of symbols as defined in the model config.json file.
173
-
174
- max_initial_timestamp: gr.Number
175
- This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
176
-
177
- word_timestamps: gr.Checkbox
178
- This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
179
- and dynamic time warping, and include the timestamps for each word in each segment.
180
-
181
- prepend_punctuations: gr.Textbox
182
- This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
183
- with the next word.
184
-
185
- append_punctuations: gr.Textbox
186
- This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
187
- with the previous word.
188
-
189
- max_new_tokens: gr.Number
190
- This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
191
- the maximum will be set by the default max_length.
192
-
193
- chunk_length: gr.Number
194
- This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
195
- If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.
196
-
197
- hallucination_silence_threshold: gr.Number
198
- This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
199
- (in seconds) when a possible hallucination is detected.
200
-
201
- hotwords: gr.Textbox
202
- This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
203
-
204
- language_detection_threshold: gr.Number
205
- This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
206
-
207
- language_detection_segments: gr.Number
208
- This parameter is related to faster-whisper. Number of segments to consider for the language detection.
209
-
210
- is_separate_bgm: gr.Checkbox
211
- This parameter is related to UVR. Boolean value that determines whether to separate bgm or not.
212
-
213
- uvr_model_size: gr.Dropdown
214
- This parameter is related to UVR. UVR model size.
215
-
216
- uvr_device: gr.Dropdown
217
- This parameter is related to UVR. Device to run UVR model.
218
-
219
- uvr_segment_size: gr.Number
220
- This parameter is related to UVR. Segment size for UVR model.
221
-
222
- uvr_save_file: gr.Checkbox
223
- This parameter is related to UVR. Boolean value that determines whether to save the file or not.
224
-
225
- uvr_enable_offload: gr.Checkbox
226
- This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
227
- after each transcription.
228
- """
229
-
230
- def as_list(self) -> list:
231
- """
232
- Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
233
- See more about Gradio pre-processing: : https://www.gradio.app/docs/components
234
-
235
- Returns
236
- ----------
237
- A list of Gradio components
238
- """
239
- return [getattr(self, f.name) for f in fields(self)]
240
-
241
- @staticmethod
242
- def as_value(*args) -> 'WhisperValues':
243
- """
244
- To use Whisper parameters in function after Gradio post-processing.
245
- See more about Gradio post-processing: : https://www.gradio.app/docs/components
246
-
247
- Returns
248
- ----------
249
- WhisperValues
250
- Data class that has values of parameters
251
- """
252
- return WhisperValues(*args)
253
-
254
-
255
- @dataclass
256
- class WhisperValues:
257
- model_size: str = "large-v2"
258
- lang: Optional[str] = None
259
- is_translate: bool = False
260
- beam_size: int = 5
261
- log_prob_threshold: float = -1.0
262
- no_speech_threshold: float = 0.6
263
- compute_type: str = "float16"
264
- best_of: int = 5
265
- patience: float = 1.0
266
- condition_on_previous_text: bool = True
267
- prompt_reset_on_temperature: float = 0.5
268
- initial_prompt: Optional[str] = None
269
- temperature: float = 0.0
270
- compression_ratio_threshold: float = 2.4
271
- vad_filter: bool = False
272
- threshold: float = 0.5
273
- min_speech_duration_ms: int = 250
274
- max_speech_duration_s: float = float("inf")
275
- min_silence_duration_ms: int = 2000
276
- speech_pad_ms: int = 400
277
- batch_size: int = 24
278
- is_diarize: bool = False
279
- hf_token: str = ""
280
- diarization_device: str = "cuda"
281
- length_penalty: float = 1.0
282
- repetition_penalty: float = 1.0
283
- no_repeat_ngram_size: int = 0
284
- prefix: Optional[str] = None
285
- suppress_blank: bool = True
286
- suppress_tokens: Optional[str] = "[-1]"
287
- max_initial_timestamp: float = 0.0
288
- word_timestamps: bool = False
289
- prepend_punctuations: Optional[str] = "\"'“¿([{-"
290
- append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
291
- max_new_tokens: Optional[int] = None
292
- chunk_length: Optional[int] = 30
293
- hallucination_silence_threshold: Optional[float] = None
294
- hotwords: Optional[str] = None
295
- language_detection_threshold: Optional[float] = None
296
- language_detection_segments: int = 1
297
- is_bgm_separate: bool = False
298
- uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
299
- uvr_device: str = "cuda"
300
- uvr_segment_size: int = 256
301
- uvr_save_file: bool = False
302
- uvr_enable_offload: bool = True
303
- """
304
- A data class to use Whisper parameters.
305
- """
306
-
307
- def to_yaml(self) -> Dict:
308
- data = {
309
- "whisper": {
310
- "model_size": self.model_size,
311
- "lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang,
312
- "is_translate": self.is_translate,
313
- "beam_size": self.beam_size,
314
- "log_prob_threshold": self.log_prob_threshold,
315
- "no_speech_threshold": self.no_speech_threshold,
316
- "best_of": self.best_of,
317
- "patience": self.patience,
318
- "condition_on_previous_text": self.condition_on_previous_text,
319
- "prompt_reset_on_temperature": self.prompt_reset_on_temperature,
320
- "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
321
- "temperature": self.temperature,
322
- "compression_ratio_threshold": self.compression_ratio_threshold,
323
- "batch_size": self.batch_size,
324
- "length_penalty": self.length_penalty,
325
- "repetition_penalty": self.repetition_penalty,
326
- "no_repeat_ngram_size": self.no_repeat_ngram_size,
327
- "prefix": None if not self.prefix else self.prefix,
328
- "suppress_blank": self.suppress_blank,
329
- "suppress_tokens": self.suppress_tokens,
330
- "max_initial_timestamp": self.max_initial_timestamp,
331
- "word_timestamps": self.word_timestamps,
332
- "prepend_punctuations": self.prepend_punctuations,
333
- "append_punctuations": self.append_punctuations,
334
- "max_new_tokens": self.max_new_tokens,
335
- "chunk_length": self.chunk_length,
336
- "hallucination_silence_threshold": self.hallucination_silence_threshold,
337
- "hotwords": None if not self.hotwords else self.hotwords,
338
- "language_detection_threshold": self.language_detection_threshold,
339
- "language_detection_segments": self.language_detection_segments,
340
- },
341
- "vad": {
342
- "vad_filter": self.vad_filter,
343
- "threshold": self.threshold,
344
- "min_speech_duration_ms": self.min_speech_duration_ms,
345
- "max_speech_duration_s": self.max_speech_duration_s,
346
- "min_silence_duration_ms": self.min_silence_duration_ms,
347
- "speech_pad_ms": self.speech_pad_ms,
348
- },
349
- "diarization": {
350
- "is_diarize": self.is_diarize,
351
- "hf_token": self.hf_token
352
- },
353
- "bgm_separation": {
354
- "is_separate_bgm": self.is_bgm_separate,
355
- "model_size": self.uvr_model_size,
356
- "segment_size": self.uvr_segment_size,
357
- "save_file": self.uvr_save_file,
358
- "enable_offload": self.uvr_enable_offload
359
- },
360
- }
361
- return data
362
-
363
- def as_list(self) -> list:
364
- """
365
- Converts the data class attributes into a list
366
-
367
- Returns
368
- ----------
369
- A list of Whisper parameters
370
- """
371
- return [getattr(self, f.name) for f in fields(self)]
tests/test_bgm_separation.py CHANGED
@@ -1,6 +1,6 @@
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
- from modules.whisper.whisper_parameter import WhisperValues
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
@@ -17,9 +17,9 @@ import os
17
  @pytest.mark.parametrize(
18
  "whisper_type,vad_filter,bgm_separation,diarization",
19
  [
20
- ("whisper", False, True, False),
21
- ("faster-whisper", False, True, False),
22
- ("insanely_fast_whisper", False, True, False)
23
  ]
24
  )
25
  def test_bgm_separation_pipeline(
@@ -38,9 +38,9 @@ def test_bgm_separation_pipeline(
38
  @pytest.mark.parametrize(
39
  "whisper_type,vad_filter,bgm_separation,diarization",
40
  [
41
- ("whisper", True, True, False),
42
- ("faster-whisper", True, True, False),
43
- ("insanely_fast_whisper", True, True, False)
44
  ]
45
  )
46
  def test_bgm_separation_with_vad_pipeline(
 
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
+ from modules.whisper.data_classes import *
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
 
17
  @pytest.mark.parametrize(
18
  "whisper_type,vad_filter,bgm_separation,diarization",
19
  [
20
+ (WhisperImpl.WHISPER.value, False, True, False),
21
+ (WhisperImpl.FASTER_WHISPER.value, False, True, False),
22
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False)
23
  ]
24
  )
25
  def test_bgm_separation_pipeline(
 
38
  @pytest.mark.parametrize(
39
  "whisper_type,vad_filter,bgm_separation,diarization",
40
  [
41
+ (WhisperImpl.WHISPER.value, True, True, False),
42
+ (WhisperImpl.FASTER_WHISPER.value, True, True, False),
43
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False)
44
  ]
45
  )
46
  def test_bgm_separation_with_vad_pipeline(
tests/test_config.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
7
  TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
8
  TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
9
- TEST_WHISPER_MODEL = "tiny"
10
  TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
11
  TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
12
  TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
 
6
  TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
7
  TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
8
  TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
9
+ TEST_WHISPER_MODEL = "tiny.en"
10
  TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
11
  TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
12
  TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
tests/test_diarization.py CHANGED
@@ -1,6 +1,6 @@
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
- from modules.whisper.whisper_parameter import WhisperValues
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
@@ -16,9 +16,9 @@ import os
16
  @pytest.mark.parametrize(
17
  "whisper_type,vad_filter,bgm_separation,diarization",
18
  [
19
- ("whisper", False, False, True),
20
- ("faster-whisper", False, False, True),
21
- ("insanely_fast_whisper", False, False, True)
22
  ]
23
  )
24
  def test_diarization_pipeline(
 
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
+ from modules.whisper.data_classes import *
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
 
16
  @pytest.mark.parametrize(
17
  "whisper_type,vad_filter,bgm_separation,diarization",
18
  [
19
+ (WhisperImpl.WHISPER.value, False, False, True),
20
+ (WhisperImpl.FASTER_WHISPER.value, False, False, True),
21
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True)
22
  ]
23
  )
24
  def test_diarization_pipeline(
tests/test_transcription.py CHANGED
@@ -1,5 +1,5 @@
1
  from modules.whisper.whisper_factory import WhisperFactory
2
- from modules.whisper.whisper_parameter import WhisperValues
3
  from modules.utils.paths import WEBUI_DIR
4
  from test_config import *
5
 
@@ -12,9 +12,9 @@ import os
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
15
- ("whisper", False, False, False),
16
- ("faster-whisper", False, False, False),
17
- ("insanely_fast_whisper", False, False, False)
18
  ]
19
  )
20
  def test_transcribe(
@@ -37,14 +37,22 @@ def test_transcribe(
37
  f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
38
  )
39
 
40
- hparams = WhisperValues(
41
- model_size=TEST_WHISPER_MODEL,
42
- vad_filter=vad_filter,
43
- is_bgm_separate=bgm_separation,
44
- compute_type=whisper_inferencer.current_compute_type,
45
- uvr_enable_offload=True,
46
- is_diarize=diarization,
47
- ).as_list()
 
 
 
 
 
 
 
 
48
 
49
  subtitle_str, file_path = whisper_inferencer.transcribe_file(
50
  [audio_path],
 
1
  from modules.whisper.whisper_factory import WhisperFactory
2
+ from modules.whisper.data_classes import *
3
  from modules.utils.paths import WEBUI_DIR
4
  from test_config import *
5
 
 
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
15
+ (WhisperImpl.WHISPER.value, False, False, False),
16
+ (WhisperImpl.FASTER_WHISPER.value, False, False, False),
17
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False)
18
  ]
19
  )
20
  def test_transcribe(
 
37
  f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
38
  )
39
 
40
+ hparams = TranscriptionPipelineParams(
41
+ whisper=WhisperParams(
42
+ model_size=TEST_WHISPER_MODEL,
43
+ compute_type=whisper_inferencer.current_compute_type
44
+ ),
45
+ vad=VadParams(
46
+ vad_filter=vad_filter
47
+ ),
48
+ bgm_separation=BGMSeparationParams(
49
+ is_separate_bgm=bgm_separation,
50
+ enable_offload=True
51
+ ),
52
+ diarization=DiarizationParams(
53
+ is_diarize=diarization
54
+ ),
55
+ ).to_list()
56
 
57
  subtitle_str, file_path = whisper_inferencer.transcribe_file(
58
  [audio_path],
tests/test_vad.py CHANGED
@@ -1,6 +1,6 @@
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
- from modules.whisper.whisper_parameter import WhisperValues
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
@@ -12,9 +12,9 @@ import os
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
15
- ("whisper", True, False, False),
16
- ("faster-whisper", True, False, False),
17
- ("insanely_fast_whisper", True, False, False)
18
  ]
19
  )
20
  def test_vad_pipeline(
 
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
+ from modules.whisper.data_classes import *
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
 
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
15
+ (WhisperImpl.WHISPER.value, True, False, False),
16
+ (WhisperImpl.FASTER_WHISPER.value, True, False, False),
17
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False)
18
  ]
19
  )
20
  def test_vad_pipeline(