Merge pull request #363 from jhj0517/feature/refactor-models
- app.py +25 -156
- modules/diarize/diarize_pipeline.py +1 -0
- modules/diarize/diarizer.py +16 -8
- modules/translation/translation_base.py +1 -1
- modules/utils/constants.py +3 -0
- modules/utils/subtitle_manager.py +11 -0
- modules/vad/silero_vad.py +6 -5
- modules/whisper/{whisper_base.py → base_transcription_pipeline.py} +91 -51
- modules/whisper/data_classes.py +565 -0
- modules/whisper/faster_whisper_inference.py +12 -22
- modules/whisper/insanely_fast_whisper_inference.py +18 -10
- modules/whisper/whisper_Inference.py +28 -21
- modules/whisper/whisper_factory.py +8 -14
- modules/whisper/whisper_parameter.py +0 -371
- tests/test_bgm_separation.py +7 -7
- tests/test_config.py +1 -1
- tests/test_diarization.py +4 -4
- tests/test_transcription.py +20 -12
- tests/test_vad.py +4 -4
app.py
CHANGED
@@ -7,17 +7,14 @@ import yaml
 from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
                                  INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                  UVR_MODELS_DIR, I18N_YAML_PATH)
-from modules.utils.constants import AUTOMATIC_DETECTION
 from modules.utils.files_manager import load_yaml
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.faster_whisper_inference import FasterWhisperInference
-from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference
 from modules.ui.htmls import *
 from modules.utils.cli_manager import str2bool
 from modules.utils.youtube_manager import get_ytmetas
 from modules.translation.deepl_api import DeepLAPI
-from modules.whisper.
+from modules.whisper.data_classes import *


 class App:
@@ -44,7 +41,7 @@ class App:
         print(f"Use \"{self.args.whisper_type}\" implementation\n"
               f"Device \"{self.whisper_inf.device}\" is detected")

-    def
+    def create_pipeline_inputs(self):
         whisper_params = self.default_params["whisper"]
         vad_params = self.default_params["vad"]
         diarization_params = self.default_params["diarization"]
@@ -66,158 +63,31 @@ class App:
                                  interactive=True)

         with gr.Accordion(_("Advanced Parameters"), open=False):
-
-
-
-
-                value=whisper_params["log_prob_threshold"], interactive=True,
-                info="If the average log probability over sampled tokens is below this value, treat as failed.")
-            nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"],
-                interactive=True,
-                info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
-            dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
-                value=self.whisper_inf.current_compute_type, interactive=True,
-                allow_custom_value=True,
-                info="Select the type of computation to perform.")
-            nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
-                info="Number of candidates when sampling with non-zero temperature.")
-            nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
-                info="Beam search patience factor.")
-            cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text",
-                value=whisper_params["condition_on_previous_text"],
-                interactive=True,
-                info="Condition on previous text during decoding.")
-            sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature",
-                value=whisper_params["prompt_reset_on_temperature"],
-                minimum=0, maximum=1, step=0.01, interactive=True,
-                info="Resets prompt if temperature is above this value."
-                     " Arg has effect only if 'Condition On Previous Text' is True.")
-            tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
-                info="Initial prompt to use for decoding.")
-            sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
-                step=0.01, maximum=1.0, interactive=True,
-                info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
-            nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold",
-                value=whisper_params["compression_ratio_threshold"],
-                interactive=True,
-                info="If the gzip compression ratio is above this value, treat as failed.")
-            nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"],
-                precision=0,
-                info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
-            with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
-                nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
-                    info="Exponential length penalty constant.")
-                nb_repetition_penalty = gr.Number(label="Repetition Penalty",
-                    value=whisper_params["repetition_penalty"],
-                    info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
-                nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size",
-                    value=whisper_params["no_repeat_ngram_size"],
-                    precision=0,
-                    info="Prevent repetitions of n-grams with this size (set 0 to disable).")
-                tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
-                    info="Optional text to provide as a prefix for the first window.")
-                cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
-                    info="Suppress blank outputs at the beginning of the sampling.")
-                tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
-                    info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
-                nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp",
-                    value=whisper_params["max_initial_timestamp"],
-                    info="The initial timestamp cannot be later than this.")
-                cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
-                    info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
-                tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations",
-                    value=whisper_params["prepend_punctuations"],
-                    info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
-                tb_append_punctuations = gr.Textbox(label="Append Punctuations",
-                    value=whisper_params["append_punctuations"],
-                    info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
-                nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
-                    precision=0,
-                    info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
-                nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
-                    value=lambda: whisper_params[
-                        "hallucination_silence_threshold"],
-                    info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
-                tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
-                    info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
-                nb_language_detection_threshold = gr.Number(label="Language Detection Threshold",
-                    value=lambda: whisper_params[
-                        "language_detection_threshold"],
-                    info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
-                nb_language_detection_segments = gr.Number(label="Language Detection Segments",
-                    value=lambda: whisper_params["language_detection_segments"],
-                    precision=0,
-                    info="Number of segments to consider for the language detection.")
-            with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
-                nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
+            whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True,
+                                                            whisper_type=self.args.whisper_type,
+                                                            available_compute_types=self.whisper_inf.available_compute_types,
+                                                            compute_type=self.whisper_inf.current_compute_type)

         with gr.Accordion(_("Background Music Remover Filter"), open=False):
-
-
-
-
-            dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device,
-                choices=self.whisper_inf.music_separator.available_devices)
-            dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"],
-                choices=self.whisper_inf.music_separator.available_models)
-            nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0)
-            cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"])
-            cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"),
-                value=uvr_params["enable_offload"])
+            uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params,
+                                                             available_models=self.whisper_inf.music_separator.available_models,
+                                                             available_devices=self.whisper_inf.music_separator.available_devices,
+                                                             device=self.whisper_inf.music_separator.device)

         with gr.Accordion(_("Voice Detection Filter"), open=False):
-
-                interactive=True,
-                info=_("Enable this to transcribe only detected voice"))
-            sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
-                value=vad_params["threshold"],
-                info="Lower it to be more sensitive to small sounds.")
-            nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
-                value=vad_params["min_speech_duration_ms"],
-                info="Final speech chunks shorter than this time are thrown out")
-            nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
-                value=vad_params["max_speech_duration_s"],
-                info="Maximum duration of speech chunks in \"seconds\".")
-            nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0,
-                value=vad_params["min_silence_duration_ms"],
-                info="In the end of each speech chunk wait for this time"
-                     " before separating it")
-            nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
-                info="Final speech chunks are padded by this time each side")
+            vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params)

         with gr.Accordion(_("Diarization"), open=False):
-
-
-
-            dd_diarization_device = gr.Dropdown(label=_("Device"),
-                choices=self.whisper_inf.diarizer.get_available_device(),
-                value=self.whisper_inf.diarizer.get_device())
+            diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params,
+                                                                    available_devices=self.whisper_inf.diarizer.available_device,
+                                                                    device=self.whisper_inf.diarizer.device)

         dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])

+        pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs
+
         return (
-
-            model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size,
-            log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold,
-            compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience,
-            condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt,
-            temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
-            vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
-            max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
-            speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size,
-            is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
-            length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
-            no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
-            suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
-            word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
-            append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens,
-            hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
-            language_detection_threshold=nb_language_detection_threshold,
-            language_detection_segments=nb_language_detection_segments,
-            prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation,
-            uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size,
-            uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload
-            ),
+            pipeline_inputs,
             dd_file_format,
             cb_timestamp
         )
@@ -243,7 +113,7 @@ class App:
                                      visible=self.args.colab,
                                      value="")

-
+            pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()

             with gr.Row():
                 btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -254,7 +124,7 @@ class App:

             params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
             btn_run.click(fn=self.whisper_inf.transcribe_file,
-                          inputs=params +
+                          inputs=params + pipeline_params,
                           outputs=[tb_indicator, files_subtitles])
             btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)

@@ -268,7 +138,7 @@ class App:
                 tb_title = gr.Label(label=_("Youtube Title"))
                 tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15)

-
+            pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()

             with gr.Row():
                 btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -280,7 +150,7 @@ class App:
             params = [tb_youtubelink, dd_file_format, cb_timestamp]

             btn_run.click(fn=self.whisper_inf.transcribe_youtube,
-                          inputs=params +
+                          inputs=params + pipeline_params,
                           outputs=[tb_indicator, files_subtitles])
             tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
                                   outputs=[img_thumbnail, tb_title, tb_description])
@@ -290,7 +160,7 @@ class App:
             with gr.Row():
                 mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True)

-
+            pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()

             with gr.Row():
                 btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -302,7 +172,7 @@ class App:
             params = [mic_input, dd_file_format, cb_timestamp]

             btn_run.click(fn=self.whisper_inf.transcribe_mic,
-                          inputs=params +
+                          inputs=params + pipeline_params,
                           outputs=[tb_indicator, files_subtitles])
             btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)

@@ -417,7 +287,6 @@ class App:

        # Launch the app with optional gradio settings
        args = self.args
-
        self.app.queue(
            api_open=args.api_open
        ).launch(
@@ -447,8 +316,8 @@ class App:


    parser = argparse.ArgumentParser()
-    parser.add_argument('--whisper_type', type=str, default=
-                        choices=[
+    parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value,
+                        choices=[item.value for item in WhisperImpl],
                        help='A type of the whisper implementation (Github repo name)')
    parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value')
    parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
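The accordion bodies above are now built by the `to_gradio_inputs()` helpers from `modules/whisper/data_classes.py`. A usage sketch; the `defaults` dict here is illustrative, in `app.py` it comes from the parameters yaml:

```python
import gradio as gr
from modules.whisper.data_classes import VadParams

with gr.Blocks():
    # Missing keys fall back to the pydantic field defaults, since the helper
    # reads values with defaults.get(...).
    vad_inputs = VadParams.to_gradio_inputs(defaults={"vad_filter": True, "threshold": 0.5})
    # vad_inputs is a flat list of gradio components; app.py concatenates it with
    # the other groups and passes the result straight to btn_run.click(inputs=...).
```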
modules/diarize/diarize_pipeline.py
CHANGED
@@ -44,6 +44,7 @@ class DiarizationPipeline:
 def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
     transcript_segments = transcript_result["segments"]
     for seg in transcript_segments:
+        seg = seg.dict()
         # assign speaker to segment (if any)
         diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
                                                                                             seg['start'])
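The single added line matters because segments now arrive as pydantic `Segment` models, while the speaker-assignment math below keeps using dict-style indexing. A minimal illustration (values are made up):

```python
from modules.whisper.data_classes import Segment

seg = Segment(start=0.0, end=2.5, text="hello")
seg = seg.dict()  # -> {'text': 'hello', 'start': 0.0, 'end': 2.5}
overlap = min(3.0, seg["end"]) - max(1.0, seg["start"])  # dict access works as before
```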
modules/diarize/diarizer.py
CHANGED
@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import List, Union, BinaryIO, Optional
+from typing import List, Union, BinaryIO, Optional, Tuple
 import numpy as np
 import time
 import logging
@@ -8,6 +8,7 @@ import logging
 from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
+from modules.whisper.data_classes import *


 class Diarizer:
@@ -23,10 +24,10 @@ class Diarizer:

     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
-            transcribed_result: List[
+            transcribed_result: List[Segment],
             use_auth_token: str,
             device: Optional[str] = None
-            ):
+            ) -> Tuple[List[Segment], float]:
         """
         Diarize transcribed result as a post-processing

@@ -34,7 +35,7 @@ class Diarizer:
         ----------
         audio: Union[str, BinaryIO, np.ndarray]
             Audio input. This can be file path or binary type.
-        transcribed_result: List[
+        transcribed_result: List[Segment]
             transcribed result through whisper.
         use_auth_token: str
             Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ class Diarizer:

         Returns
         ----------
-        segments_result: List[
-            list of
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for running
         """
@@ -68,14 +69,21 @@ class Diarizer:
             {"segments": transcribed_result}
         )

+        segments_result = []
         for segment in diarized_result["segments"]:
+            segment = segment.dict()
             speaker = "None"
             if "speaker" in segment:
                 speaker = segment["speaker"]
-
+            diarized_text = speaker + "|" + segment["text"].strip()
+            segments_result.append(Segment(
+                start=segment["start"],
+                end=segment["end"],
+                text=diarized_text
+            ))

         elapsed_time = time.time() - start_time
-        return
+        return segments_result, elapsed_time

     def update_pipe(self,
                     use_auth_token: str,
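`Diarizer.run()` now returns typed `Segment` objects whose text is prefixed with the detected speaker. A small sketch of that post-processing loop with an illustrative input:

```python
from modules.whisper.data_classes import Segment

diarized = [Segment(start=0.0, end=1.2, text=" hi there")]  # illustrative input
segments_result = []
for segment in diarized:
    segment = segment.dict()
    speaker = segment.get("speaker", "None")  # "None" when no speaker was assigned
    segments_result.append(Segment(
        start=segment["start"],
        end=segment["end"],
        text=f'{speaker}|{segment["text"].strip()}',
    ))
print(segments_result[0].text)  # "None|hi there"
```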
modules/translation/translation_base.py
CHANGED
@@ -6,7 +6,7 @@ from typing import List
 from datetime import datetime

 import modules.translation.nllb_inference as nllb
-from modules.whisper.
+from modules.whisper.data_classes import *
 from modules.utils.subtitle_manager import *
 from modules.utils.files_manager import load_yaml, save_yaml
 from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
modules/utils/constants.py
CHANGED
@@ -1,3 +1,6 @@
 from gradio_i18n import Translate, gettext as _

 AUTOMATIC_DETECTION = _("Automatic Detection")
+GRADIO_NONE_STR = ""
+GRADIO_NONE_NUMBER_MAX = 9999
+GRADIO_NONE_NUMBER_MIN = 0
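These sentinels exist because gradio `Number`/`Textbox` components cannot hold `None`; `validate_gradio_values()` later in this PR maps them back before inference. A minimal sketch of that mapping (the helper function here is hypothetical; the sentinel semantics follow the PR):

```python
GRADIO_NONE_STR = ""
GRADIO_NONE_NUMBER_MAX = 9999
GRADIO_NONE_NUMBER_MIN = 0

def from_gradio(text_value: str, duration_value: float):
    """Hypothetical helper: translate UI placeholder values back into real None/inf."""
    text = None if text_value == GRADIO_NONE_STR else text_value
    duration = float("inf") if duration_value == GRADIO_NONE_NUMBER_MAX else duration_value
    return text, duration

print(from_gradio("", 9999))  # (None, inf)
```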
modules/utils/subtitle_manager.py
CHANGED
@@ -1,5 +1,7 @@
 import re

+from modules.whisper.data_classes import Segment
+

 def timeformat_srt(time):
     hours = time // 3600
@@ -23,6 +25,9 @@ def write_file(subtitle, output_file):


 def get_srt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = ""
     for i, segment in enumerate(segments):
         output += f"{i + 1}\n"
@@ -34,6 +39,9 @@ def get_srt(segments):


 def get_vtt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = "WEBVTT\n\n"
     for i, segment in enumerate(segments):
         output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
@@ -44,6 +52,9 @@ def get_vtt(segments):


 def get_txt(segments):
+    if segments and isinstance(segments[0], Segment):
+        segments = [seg.dict() for seg in segments]
+
     output = ""
     for i, segment in enumerate(segments):
         if segment['text'].startswith(' '):
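With the added guards, the subtitle writers accept either plain dicts or the new `Segment` models. Usage sketch (output described loosely rather than byte-exactly):

```python
from modules.whisper.data_classes import Segment
from modules.utils.subtitle_manager import get_srt

segments = [Segment(start=0.0, end=1.5, text=" Hello world")]
print(get_srt(segments))  # numbered SRT block with "hh:mm:ss,mmm --> hh:mm:ss,mmm" timestamps
```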
modules/vad/silero_vad.py
CHANGED
@@ -5,7 +5,8 @@ import numpy as np
 from typing import BinaryIO, Union, List, Optional, Tuple
 import warnings
 import faster_whisper
-from
+from modules.whisper.data_classes import *
+from faster_whisper.transcribe import SpeechTimestampsMap
 import gradio as gr


@@ -247,18 +248,18 @@ class SileroVAD:

     def restore_speech_timestamps(
             self,
-            segments: List[
+            segments: List[Segment],
             speech_chunks: List[dict],
             sampling_rate: Optional[int] = None,
-    ) -> List[
+    ) -> List[Segment]:
         if sampling_rate is None:
             sampling_rate = self.sampling_rate

         ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

         for segment in segments:
-            segment
-            segment
+            segment.start = ts_map.get_original_time(segment.start)
+            segment.start = ts_map.get_original_time(segment.start)

         return segments

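`restore_speech_timestamps()` now takes and returns `Segment` models, mapping timestamps measured on the VAD-filtered audio back onto the original timeline. A usage sketch; the no-argument `SileroVAD()` constructor and the chunk values are assumptions for illustration:

```python
from modules.vad.silero_vad import SileroVAD
from modules.whisper.data_classes import Segment

vad = SileroVAD()                                     # assumed default constructor
segments = [Segment(start=0.0, end=1.0, text="hi")]   # times relative to the filtered audio
speech_chunks = [{"start": 16000, "end": 48000}]      # sample offsets into the original audio
restored = vad.restore_speech_timestamps(segments=segments, speech_chunks=speech_chunks)
# restored[0].start is now expressed on the original, unfiltered timeline
```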
modules/whisper/{whisper_base.py → base_transcription_pipeline.py}
RENAMED
@@ -1,5 +1,6 @@
 import os
 import torch
+import ast
 import whisper
 import ctranslate2
 import gradio as gr
@@ -14,16 +15,16 @@ from dataclasses import astuple
 from modules.uvr.music_separator import MusicSeparator
 from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                  UVR_MODELS_DIR)
-from modules.utils.constants import
+from modules.utils.constants import *
 from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.utils.youtube_manager import get_ytdata, get_ytaudio
 from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
-from modules.whisper.
+from modules.whisper.data_classes import *
 from modules.diarize.diarizer import Diarizer
 from modules.vad.silero_vad import SileroVAD


-class
+class BaseTranscriptionPipeline(ABC):
     def __init__(self,
                  model_dir: str = WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -74,12 +75,13 @@ class WhisperBase(ABC):
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
             add_timestamp: bool = True,
-            *
+            *pipeline_params,
             ) -> Tuple[List[dict], float]:
         """
         Run transcription with conditional pre-processing and post-processing.
         The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
         The diarization will be performed in post-processing, if enabled.
+        Due to the integration with gradio, the parameters have to be specified with a `*` wildcard.

         Parameters
         ----------
@@ -89,8 +91,8 @@ class WhisperBase(ABC):
             Indicator to show progress directly in gradio.
         add_timestamp: bool
             Whether to add a timestamp at the end of the filename.
-        *
-            Parameters
+        *pipeline_params: tuple
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class

         Returns
         ----------
@@ -99,28 +101,17 @@ class WhisperBase(ABC):
         elapsed_time: float
             elapsed time for running
         """
-        params =
-
-
-            whisper_params=params,
-            add_timestamp=add_timestamp
-        )
-
-        if
-            pass
-        elif params.lang == AUTOMATIC_DETECTION:
-            params.lang = None
-        else:
-            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-            params.lang = language_code_dict[params.lang]
-
-        if params.is_bgm_separate:
+        params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+        params = self.validate_gradio_values(params)
+        bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
+
+        if bgm_params.is_separate_bgm:
             music, audio, _ = self.music_separator.separate(
                 audio=audio,
-                model_name=
-                device=
-                segment_size=
-                save_file=
+                model_name=bgm_params.model_size,
+                device=bgm_params.device,
+                segment_size=bgm_params.segment_size,
+                save_file=bgm_params.save_file,
                 progress=progress
             )

@@ -132,47 +123,54 @@ class WhisperBase(ABC):
             origin_sample_rate = self.music_separator.audio_info.sample_rate
             audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

-            if
+            if bgm_params.enable_offload:
                 self.music_separator.offload()

-        if
-            # Explicit value set for float('inf') from gr.Number()
-            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
-                params.max_speech_duration_s = float('inf')
-
+        if vad_params.vad_filter:
             vad_options = VadOptions(
-                threshold=
-                min_speech_duration_ms=
-                max_speech_duration_s=
-                min_silence_duration_ms=
-                speech_pad_ms=
+                threshold=vad_params.threshold,
+                min_speech_duration_ms=vad_params.min_speech_duration_ms,
+                max_speech_duration_s=vad_params.max_speech_duration_s,
+                min_silence_duration_ms=vad_params.min_silence_duration_ms,
+                speech_pad_ms=vad_params.speech_pad_ms
             )

-
+            vad_processed, speech_chunks = self.vad.run(
                 audio=audio,
                 vad_parameters=vad_options,
                 progress=progress
             )

+            if vad_processed.size > 0:
+                audio = vad_processed
+            else:
+                vad_params.vad_filter = False
+
         result, elapsed_time = self.transcribe(
             audio,
             progress,
-            *
+            *whisper_params.to_list()
         )

-        if
+        if vad_params.vad_filter:
             result = self.vad.restore_speech_timestamps(
                 segments=result,
                 speech_chunks=speech_chunks,
             )

-        if
+        if diarization_params.is_diarize:
             result, elapsed_time_diarization = self.diarizer.run(
                 audio=audio,
-                use_auth_token=
+                use_auth_token=diarization_params.hf_token,
                 transcribed_result=result,
+                device=diarization_params.device
             )
             elapsed_time += elapsed_time_diarization
+
+        self.cache_parameters(
+            params=params,
+            add_timestamp=add_timestamp
+        )
         return result, elapsed_time

     def transcribe_file(self,
@@ -181,7 +179,7 @@ class WhisperBase(ABC):
                         file_format: str = "SRT",
                         add_timestamp: bool = True,
                         progress=gr.Progress(),
-                        *
+                        *params,
                         ) -> list:
         """
         Write subtitle file from Files
@@ -199,8 +197,8 @@ class WhisperBase(ABC):
             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
-        *
-            Parameters
+        *params: tuple
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class

         Returns
         ----------
@@ -223,7 +221,7 @@ class WhisperBase(ABC):
                     file,
                     progress,
                     add_timestamp,
-                    *
+                    *params,
                 )

                 file_name, file_ext = os.path.splitext(os.path.basename(file))
@@ -471,7 +469,7 @@ class WhisperBase(ABC):
         if torch.cuda.is_available():
             return "cuda"
         elif torch.backends.mps.is_available():
-            if not
+            if not BaseTranscriptionPipeline.is_sparse_api_supported():
                 # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
                 return "cpu"
             return "mps"
@@ -512,18 +510,60 @@ class WhisperBase(ABC):
         if file_path and os.path.exists(file_path):
             os.remove(file_path)

+    @staticmethod
+    def validate_gradio_values(params: TranscriptionPipelineParams):
+        """
+        Validate gradio specific values that can't be displayed as None in the UI.
+        Related issue : https://github.com/gradio-app/gradio/issues/8723
+        """
+        if params.whisper.lang is None:
+            pass
+        elif params.whisper.lang == AUTOMATIC_DETECTION:
+            params.whisper.lang = None
+        else:
+            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
+            params.whisper.lang = language_code_dict[params.lang]
+
+        if params.whisper.initial_prompt == GRADIO_NONE_STR:
+            params.whisper.initial_prompt = None
+        if params.whisper.prefix == GRADIO_NONE_STR:
+            params.whisper.prefix = None
+        if params.whisper.hotwords == GRADIO_NONE_STR:
+            params.whisper.hotwords = None
+        if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN:
+            params.whisper.max_new_tokens = None
+        if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN:
+            params.whisper.hallucination_silence_threshold = None
+        if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN:
+            params.whisper.language_detection_threshold = None
+        if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
+            params.vad.max_speech_duration_s = float('inf')
+        return params
+
     @staticmethod
     def cache_parameters(
-
+            params: TranscriptionPipelineParams,
             add_timestamp: bool
     ):
-        """
+        """Cache parameters to the yaml file"""
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-
-
+        param_to_cache = params.to_dict()
+
+        cached_yaml = {**cached_params, **param_to_cache}
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp

-
+        supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
+        if supress_token and isinstance(supress_token, list):
+            cached_yaml["whisper"]["suppress_tokens"] = str(supress_token)
+
+        if cached_yaml["whisper"].get("lang", None) is None:
+            cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
+
+        if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
+            cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
+
+        if cached_yaml is not None and cached_yaml:
+            save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)

     @staticmethod
     def resample_audio(audio: Union[str, np.ndarray],
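One detail worth noting in `cache_parameters()`: `suppress_tokens` is a Python list at runtime but is cached as its string form, which is exactly what the `validate_supress_tokens` field validator in `data_classes.py` parses back with `ast.literal_eval`. A minimal round-trip sketch:

```python
import ast

runtime_value = [-1]                        # list while transcribing
cached_value = str(runtime_value)           # "[-1]" as written to the parameters yaml
restored = ast.literal_eval(cached_value)   # back to [-1], as the validator does on reload
assert restored == runtime_value
```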
modules/whisper/data_classes.py
ADDED
@@ -0,0 +1,565 @@
+import gradio as gr
+import torch
+from typing import Optional, Dict, List, Union
+from pydantic import BaseModel, Field, field_validator, ConfigDict
+from gradio_i18n import Translate, gettext as _
+from enum import Enum
+from copy import deepcopy
+import yaml
+
+from modules.utils.constants import *
+
+
+class WhisperImpl(Enum):
+    WHISPER = "whisper"
+    FASTER_WHISPER = "faster-whisper"
+    INSANELY_FAST_WHISPER = "insanely_fast_whisper"
+
+
+class Segment(BaseModel):
+    text: Optional[str] = Field(default=None,
+                                description="Transcription text of the segment")
+    start: Optional[float] = Field(default=None,
+                                   description="Start time of the segment")
+    end: Optional[float] = Field(default=None,
+                                 description="End time of the segment")
+
+
+class BaseParams(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
+    def to_dict(self) -> Dict:
+        return self.model_dump()
+
+    def to_list(self) -> List:
+        return list(self.model_dump().values())
+
+    @classmethod
+    def from_list(cls, data_list: List) -> 'BaseParams':
+        field_names = list(cls.model_fields.keys())
+        return cls(**dict(zip(field_names, data_list)))
+
+
+class VadParams(BaseParams):
+    """Voice Activity Detection parameters"""
+    vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
+    threshold: float = Field(
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+        description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
+    )
+    min_speech_duration_ms: int = Field(
+        default=250,
+        ge=0,
+        description="Final speech chunks shorter than this are discarded"
+    )
+    max_speech_duration_s: float = Field(
+        default=float("inf"),
+        gt=0,
+        description="Maximum duration of speech chunks in seconds"
+    )
+    min_silence_duration_ms: int = Field(
+        default=2000,
+        ge=0,
+        description="Minimum silence duration between speech chunks"
+    )
+    speech_pad_ms: int = Field(
+        default=400,
+        ge=0,
+        description="Padding added to each side of speech chunks"
+    )
+
+    @classmethod
+    def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
+        return [
+            gr.Checkbox(
+                label=_("Enable Silero VAD Filter"),
+                value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default),
+                interactive=True,
+                info=_("Enable this to transcribe only detected voice")
+            ),
+            gr.Slider(
+                minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
+                value=defaults.get("threshold", cls.__fields__["threshold"].default),
+                info="Lower it to be more sensitive to small sounds."
+            ),
+            gr.Number(
+                label="Minimum Speech Duration (ms)", precision=0,
+                value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default),
+                info="Final speech chunks shorter than this time are thrown out"
+            ),
+            gr.Number(
+                label="Maximum Speech Duration (s)",
+                value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX),
+                info="Maximum duration of speech chunks in \"seconds\"."
+            ),
+            gr.Number(
+                label="Minimum Silence Duration (ms)", precision=0,
+                value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default),
+                info="In the end of each speech chunk wait for this time before separating it"
+            ),
+            gr.Number(
+                label="Speech Padding (ms)", precision=0,
+                value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default),
+                info="Final speech chunks are padded by this time each side"
+            )
+        ]
+
+
+class DiarizationParams(BaseParams):
+    """Speaker diarization parameters"""
+    is_diarize: bool = Field(default=False, description="Enable speaker diarization")
+    device: str = Field(default="cuda", description="Device to run Diarization model.")
+    hf_token: str = Field(
+        default="",
+        description="Hugging Face token for downloading diarization models"
+    )
+
+    @classmethod
+    def to_gradio_inputs(cls,
+                         defaults: Optional[Dict] = None,
+                         available_devices: Optional[List] = None,
+                         device: Optional[str] = None) -> List[gr.components.base.FormComponent]:
+        return [
+            gr.Checkbox(
+                label=_("Enable Diarization"),
+                value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default),
+            ),
+            gr.Dropdown(
+                label=_("Device"),
+                choices=["cpu", "cuda"] if available_devices is None else available_devices,
+                value=defaults.get("device", device),
+            ),
+            gr.Textbox(
+                label=_("HuggingFace Token"),
+                value=defaults.get("hf_token", cls.__fields__["hf_token"].default),
+                info=_("This is only needed the first time you download the model")
+            ),
+        ]
+
+
+class BGMSeparationParams(BaseParams):
+    """Background music separation parameters"""
+    is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
+    model_size: str = Field(
+        default="UVR-MDX-NET-Inst_HQ_4",
+        description="UVR model size"
+    )
+    device: str = Field(default="cuda", description="Device to run UVR model.")
+    segment_size: int = Field(
+        default=256,
+        gt=0,
+        description="Segment size for UVR model"
+    )
+    save_file: bool = Field(
+        default=False,
+        description="Whether to save separated audio files"
+    )
+    enable_offload: bool = Field(
+        default=True,
+        description="Offload UVR model after transcription"
+    )
+
+    @classmethod
+    def to_gradio_input(cls,
+                        defaults: Optional[Dict] = None,
+                        available_devices: Optional[List] = None,
+                        device: Optional[str] = None,
+                        available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]:
+        return [
+            gr.Checkbox(
+                label=_("Enable Background Music Remover Filter"),
+                value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default),
+                interactive=True,
+                info=_("Enabling this will remove background music")
+            ),
+            gr.Dropdown(
+                label=_("Model"),
+                choices=["UVR-MDX-NET-Inst_HQ_4",
+                         "UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
+                value=defaults.get("model_size", cls.__fields__["model_size"].default),
+            ),
+            gr.Dropdown(
+                label=_("Device"),
+                choices=["cpu", "cuda"] if available_devices is None else available_devices,
+                value=defaults.get("device", device),
+            ),
+            gr.Number(
+                label="Segment Size",
+                value=defaults.get("segment_size", cls.__fields__["segment_size"].default),
+                precision=0,
+                info="Segment size for UVR model"
+            ),
+            gr.Checkbox(
+                label=_("Save separated files to output"),
+                value=defaults.get("save_file", cls.__fields__["save_file"].default),
+            ),
+            gr.Checkbox(
+                label=_("Offload sub model after removing background music"),
+                value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default),
+            )
+        ]
+
+
+class WhisperParams(BaseParams):
+    """Whisper parameters"""
+    model_size: str = Field(default="large-v2", description="Whisper model size")
+    lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
+    is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
+    beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
+    log_prob_threshold: float = Field(
+        default=-1.0,
+        description="Threshold for average log probability of sampled tokens"
+    )
+    no_speech_threshold: float = Field(
+        default=0.6,
+        ge=0.0,
+        le=1.0,
+        description="Threshold for detecting silence"
+    )
+    compute_type: str = Field(default="float16", description="Computation type for transcription")
+    best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
+    patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
+    condition_on_previous_text: bool = Field(
+        default=True,
+        description="Use previous output as prompt for next window"
+    )
+    prompt_reset_on_temperature: float = Field(
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+        description="Temperature threshold for resetting prompt"
+    )
+    initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
+    temperature: float = Field(
+        default=0.0,
+        ge=0.0,
+        description="Temperature for sampling"
+    )
+    compression_ratio_threshold: float = Field(
+        default=2.4,
+        gt=0,
+        description="Threshold for gzip compression ratio"
+    )
+    length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
+    repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
+    no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
+    prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
+    suppress_blank: bool = Field(
+        default=True,
+        description="Suppress blank outputs at start of sampling"
+    )
+    suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress")
+    max_initial_timestamp: float = Field(
+        default=0.0,
+        ge=0.0,
+        description="Maximum initial timestamp"
+    )
+    word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
+    prepend_punctuations: Optional[str] = Field(
+        default="\"'“¿([{-",
+        description="Punctuations to merge with next word"
+    )
+    append_punctuations: Optional[str] = Field(
+        default="\"'.。,,!!??::”)]}、",
+        description="Punctuations to merge with previous word"
+    )
+    max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
+    chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
+    hallucination_silence_threshold: Optional[float] = Field(
+        default=None,
+        description="Threshold for skipping silent periods in hallucination detection"
+    )
+    hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
+    language_detection_threshold: Optional[float] = Field(
+        default=None,
+        description="Threshold for language detection probability"
+    )
+    language_detection_segments: int = Field(
+        default=1,
+        gt=0,
+        description="Number of segments for language detection"
+    )
+    batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
+
+    @field_validator('lang')
+    def validate_lang(cls, v):
+        from modules.utils.constants import AUTOMATIC_DETECTION
+        return None if v == AUTOMATIC_DETECTION.unwrap() else v
+
+    @field_validator('suppress_tokens')
+    def validate_supress_tokens(cls, v):
+        import ast
+        try:
+            if isinstance(v, str):
+                suppress_tokens = ast.literal_eval(v)
+                if not isinstance(suppress_tokens, list):
+                    raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
+                return suppress_tokens
+            if isinstance(v, list):
+                return v
|
302 |
+
except Exception as e:
|
303 |
+
raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")
|
304 |
+
|
305 |
+
@classmethod
|
306 |
+
def to_gradio_inputs(cls,
|
307 |
+
defaults: Optional[Dict] = None,
|
308 |
+
only_advanced: Optional[bool] = True,
|
309 |
+
whisper_type: Optional[str] = None,
|
310 |
+
available_models: Optional[List] = None,
|
311 |
+
available_langs: Optional[List] = None,
|
312 |
+
available_compute_types: Optional[List] = None,
|
313 |
+
compute_type: Optional[str] = None):
|
314 |
+
whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower()
|
315 |
+
|
316 |
+
inputs = []
|
317 |
+
if not only_advanced:
|
318 |
+
inputs += [
|
319 |
+
gr.Dropdown(
|
320 |
+
label=_("Model"),
|
321 |
+
choices=available_models,
|
322 |
+
value=defaults.get("model_size", cls.__fields__["model_size"].default),
|
323 |
+
),
|
324 |
+
gr.Dropdown(
|
325 |
+
label=_("Language"),
|
326 |
+
choices=available_langs,
|
327 |
+
value=defaults.get("lang", AUTOMATIC_DETECTION),
|
328 |
+
),
|
329 |
+
gr.Checkbox(
|
330 |
+
label=_("Translate to English?"),
|
331 |
+
value=defaults.get("is_translate", cls.__fields__["is_translate"].default),
|
332 |
+
),
|
333 |
+
]
|
334 |
+
|
335 |
+
inputs += [
|
336 |
+
gr.Number(
|
337 |
+
label="Beam Size",
|
338 |
+
value=defaults.get("beam_size", cls.__fields__["beam_size"].default),
|
339 |
+
precision=0,
|
340 |
+
info="Beam size for decoding"
|
341 |
+
),
|
342 |
+
gr.Number(
|
343 |
+
label="Log Probability Threshold",
|
344 |
+
value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default),
|
345 |
+
info="Threshold for average log probability of sampled tokens"
|
346 |
+
),
|
347 |
+
gr.Number(
|
348 |
+
label="No Speech Threshold",
|
349 |
+
value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default),
|
350 |
+
info="Threshold for detecting silence"
|
351 |
+
),
|
352 |
+
gr.Dropdown(
|
353 |
+
label="Compute Type",
|
354 |
+
choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types,
|
355 |
+
value=defaults.get("compute_type", compute_type),
|
356 |
+
info="Computation type for transcription"
|
357 |
+
),
|
358 |
+
gr.Number(
|
359 |
+
label="Best Of",
|
360 |
+
value=defaults.get("best_of", cls.__fields__["best_of"].default),
|
361 |
+
precision=0,
|
362 |
+
info="Number of candidates when sampling"
|
363 |
+
),
|
364 |
+
gr.Number(
|
365 |
+
label="Patience",
|
366 |
+
value=defaults.get("patience", cls.__fields__["patience"].default),
|
367 |
+
info="Beam search patience factor"
|
368 |
+
),
|
369 |
+
gr.Checkbox(
|
370 |
+
label="Condition On Previous Text",
|
371 |
+
value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default),
|
372 |
+
info="Use previous output as prompt for next window"
|
373 |
+
),
|
374 |
+
gr.Slider(
|
375 |
+
label="Prompt Reset On Temperature",
|
376 |
+
value=defaults.get("prompt_reset_on_temperature",
|
377 |
+
cls.__fields__["prompt_reset_on_temperature"].default),
|
378 |
+
minimum=0,
|
379 |
+
maximum=1,
|
380 |
+
step=0.01,
|
381 |
+
info="Temperature threshold for resetting prompt"
|
382 |
+
),
|
383 |
+
gr.Textbox(
|
384 |
+
label="Initial Prompt",
|
385 |
+
value=defaults.get("initial_prompt", GRADIO_NONE_STR),
|
386 |
+
info="Initial prompt for first window"
|
387 |
+
),
|
388 |
+
gr.Slider(
|
389 |
+
label="Temperature",
|
390 |
+
value=defaults.get("temperature", cls.__fields__["temperature"].default),
|
391 |
+
minimum=0.0,
|
392 |
+
step=0.01,
|
393 |
+
maximum=1.0,
|
394 |
+
info="Temperature for sampling"
|
395 |
+
),
|
396 |
+
gr.Number(
|
397 |
+
label="Compression Ratio Threshold",
|
398 |
+
value=defaults.get("compression_ratio_threshold",
|
399 |
+
cls.__fields__["compression_ratio_threshold"].default),
|
400 |
+
info="Threshold for gzip compression ratio"
|
401 |
+
)
|
402 |
+
]
|
403 |
+
|
404 |
+
faster_whisper_inputs = [
|
405 |
+
gr.Number(
|
406 |
+
label="Length Penalty",
|
407 |
+
value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
|
408 |
+
info="Exponential length penalty",
|
409 |
+
),
|
410 |
+
gr.Number(
|
411 |
+
label="Repetition Penalty",
|
412 |
+
value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
|
413 |
+
info="Penalty for repeated tokens"
|
414 |
+
),
|
415 |
+
gr.Number(
|
416 |
+
label="No Repeat N-gram Size",
|
417 |
+
value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
|
418 |
+
precision=0,
|
419 |
+
info="Size of n-grams to prevent repetition"
|
420 |
+
),
|
421 |
+
gr.Textbox(
|
422 |
+
label="Prefix",
|
423 |
+
value=defaults.get("prefix", GRADIO_NONE_STR),
|
424 |
+
info="Prefix text for first window"
|
425 |
+
),
|
426 |
+
gr.Checkbox(
|
427 |
+
label="Suppress Blank",
|
428 |
+
value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
|
429 |
+
info="Suppress blank outputs at start of sampling"
|
430 |
+
),
|
431 |
+
gr.Textbox(
|
432 |
+
label="Suppress Tokens",
|
433 |
+
value=defaults.get("suppress_tokens", "[-1]"),
|
434 |
+
info="Token IDs to suppress"
|
435 |
+
),
|
436 |
+
gr.Number(
|
437 |
+
label="Max Initial Timestamp",
|
438 |
+
value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
|
439 |
+
info="Maximum initial timestamp"
|
440 |
+
),
|
441 |
+
gr.Checkbox(
|
442 |
+
label="Word Timestamps",
|
443 |
+
value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
|
444 |
+
info="Extract word-level timestamps"
|
445 |
+
),
|
446 |
+
gr.Textbox(
|
447 |
+
label="Prepend Punctuations",
|
448 |
+
value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
|
449 |
+
info="Punctuations to merge with next word"
|
450 |
+
),
|
451 |
+
gr.Textbox(
|
452 |
+
label="Append Punctuations",
|
453 |
+
value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
|
454 |
+
info="Punctuations to merge with previous word"
|
455 |
+
),
|
456 |
+
gr.Number(
|
457 |
+
label="Max New Tokens",
|
458 |
+
value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN),
|
459 |
+
precision=0,
|
460 |
+
info="Maximum number of new tokens per chunk"
|
461 |
+
),
|
462 |
+
gr.Number(
|
463 |
+
label="Chunk Length (s)",
|
464 |
+
value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
|
465 |
+
precision=0,
|
466 |
+
info="Length of audio segments in seconds"
|
467 |
+
),
|
468 |
+
gr.Number(
|
469 |
+
label="Hallucination Silence Threshold (sec)",
|
470 |
+
value=defaults.get("hallucination_silence_threshold",
|
471 |
+
GRADIO_NONE_NUMBER_MIN),
|
472 |
+
info="Threshold for skipping silent periods in hallucination detection"
|
473 |
+
),
|
474 |
+
gr.Textbox(
|
475 |
+
label="Hotwords",
|
476 |
+
value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
|
477 |
+
info="Hotwords/hint phrases for the model"
|
478 |
+
),
|
479 |
+
gr.Number(
|
480 |
+
label="Language Detection Threshold",
|
481 |
+
value=defaults.get("language_detection_threshold",
|
482 |
+
GRADIO_NONE_NUMBER_MIN),
|
483 |
+
info="Threshold for language detection probability"
|
484 |
+
),
|
485 |
+
gr.Number(
|
486 |
+
label="Language Detection Segments",
|
487 |
+
value=defaults.get("language_detection_segments",
|
488 |
+
cls.__fields__["language_detection_segments"].default),
|
489 |
+
precision=0,
|
490 |
+
info="Number of segments for language detection"
|
491 |
+
)
|
492 |
+
]
|
493 |
+
|
494 |
+
insanely_fast_whisper_inputs = [
|
495 |
+
gr.Number(
|
496 |
+
label="Batch Size",
|
497 |
+
value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
|
498 |
+
precision=0,
|
499 |
+
info="Batch size for processing"
|
500 |
+
)
|
501 |
+
]
|
502 |
+
|
503 |
+
if whisper_type != WhisperImpl.FASTER_WHISPER.value:
|
504 |
+
for input_component in faster_whisper_inputs:
|
505 |
+
input_component.visible = False
|
506 |
+
|
507 |
+
if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value:
|
508 |
+
for input_component in insanely_fast_whisper_inputs:
|
509 |
+
input_component.visible = False
|
510 |
+
|
511 |
+
inputs += faster_whisper_inputs + insanely_fast_whisper_inputs
|
512 |
+
|
513 |
+
return inputs
|
514 |
+
|
515 |
+
|
516 |
+
class TranscriptionPipelineParams(BaseModel):
|
517 |
+
"""Transcription pipeline parameters"""
|
518 |
+
whisper: WhisperParams = Field(default_factory=WhisperParams)
|
519 |
+
vad: VadParams = Field(default_factory=VadParams)
|
520 |
+
diarization: DiarizationParams = Field(default_factory=DiarizationParams)
|
521 |
+
bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams)
|
522 |
+
|
523 |
+
def to_dict(self) -> Dict:
|
524 |
+
data = {
|
525 |
+
"whisper": self.whisper.to_dict(),
|
526 |
+
"vad": self.vad.to_dict(),
|
527 |
+
"diarization": self.diarization.to_dict(),
|
528 |
+
"bgm_separation": self.bgm_separation.to_dict()
|
529 |
+
}
|
530 |
+
return data
|
531 |
+
|
532 |
+
def to_list(self) -> List:
|
533 |
+
"""
|
534 |
+
Convert data class to the list because I have to pass the parameters as a list in the gradio.
|
535 |
+
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
|
536 |
+
See more about Gradio pre-processing: https://www.gradio.app/docs/components
|
537 |
+
"""
|
538 |
+
whisper_list = self.whisper.to_list()
|
539 |
+
vad_list = self.vad.to_list()
|
540 |
+
diarization_list = self.diarization.to_list()
|
541 |
+
bgm_sep_list = self.bgm_separation.to_list()
|
542 |
+
return whisper_list + vad_list + diarization_list + bgm_sep_list
|
543 |
+
|
544 |
+
@staticmethod
|
545 |
+
def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
|
546 |
+
"""Convert list to the data class again to use it in a function."""
|
547 |
+
data_list = deepcopy(pipeline_list)
|
548 |
+
|
549 |
+
whisper_list = data_list[0:len(WhisperParams.__annotations__)]
|
550 |
+
data_list = data_list[len(WhisperParams.__annotations__):]
|
551 |
+
|
552 |
+
vad_list = data_list[0:len(VadParams.__annotations__)]
|
553 |
+
data_list = data_list[len(VadParams.__annotations__):]
|
554 |
+
|
555 |
+
diarization_list = data_list[0:len(DiarizationParams.__annotations__)]
|
556 |
+
data_list = data_list[len(DiarizationParams.__annotations__):]
|
557 |
+
|
558 |
+
bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)]
|
559 |
+
|
560 |
+
return TranscriptionPipelineParams(
|
561 |
+
whisper=WhisperParams.from_list(whisper_list),
|
562 |
+
vad=VadParams.from_list(vad_list),
|
563 |
+
diarization=DiarizationParams.from_list(diarization_list),
|
564 |
+
bgm_separation=BGMSeparationParams.from_list(bgm_sep_list)
|
565 |
+
)
|
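
For reference, a minimal usage sketch of the round-trip above. It assumes that `BaseParams` supplies the `to_list()`/`from_list()` helpers used in this file and that each sub-model serializes one value per declared field; the concrete values are illustrative only:

    # Build typed params, flatten them for Gradio, then rebuild them inside the event handler.
    params = TranscriptionPipelineParams(
        whisper=WhisperParams(model_size="large-v2"),
        bgm_separation=BGMSeparationParams(is_separate_bgm=True)
    )
    flat = params.to_list()                                # flat, ordered list matching the Gradio components
    rebuilt = TranscriptionPipelineParams.from_list(flat)  # typed params again inside the function
    assert rebuilt.bgm_separation.is_separate_bgm is True
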
modules/whisper/faster_whisper_inference.py
CHANGED
@@ -12,11 +12,11 @@ import gradio as gr
 from argparse import Namespace
 
 from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
-from modules.whisper.…
-from modules.whisper.…
+from modules.whisper.data_classes import *
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
 
 
-class FasterWhisperInference(WhisperBase):
+class FasterWhisperInference(BaseTranscriptionPipeline):
     def __init__(self,
                  model_dir: str = FASTER_WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
                    audio: Union[str, BinaryIO, np.ndarray],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[…
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -55,28 +55,18 @@ class FasterWhisperInference(WhisperBase):
 
         Returns
         ----------
-        segments_result: List[…
-            list of …
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
         start_time = time.time()
 
-        params = …
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
 
-        # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
-        if not params.initial_prompt:
-            params.initial_prompt = None
-        if not params.prefix:
-            params.prefix = None
-        if not params.hotwords:
-            params.hotwords = None
-
-        params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
-
         segments, info = self.model.transcribe(
             audio=audio,
             language=params.lang,
@@ -112,11 +102,11 @@ class FasterWhisperInference(WhisperBase):
         segments_result = []
         for segment in segments:
             progress(segment.start / info.duration, desc="Transcribing..")
-            segments_result.append(…
-            …
+            segments_result.append(Segment(
+                start=segment.start,
+                end=segment.end,
+                text=segment.text
+            ))
 
         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
@@ -12,11 +12,11 @@ from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
 from argparse import Namespace
 
 from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
-from modules.whisper.…
-from modules.whisper.…
+from modules.whisper.data_classes import *
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
 
 
-class InsanelyFastWhisperInference(WhisperBase):
+class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
     def __init__(self,
                  model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -40,7 +40,7 @@ class InsanelyFastWhisperInference(WhisperBase):
                    audio: Union[str, np.ndarray, torch.Tensor],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[…
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -55,13 +55,13 @@ class InsanelyFastWhisperInference(WhisperBase):
 
         Returns
        ----------
-        segments_result: List[…
-            list of …
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
         start_time = time.time()
-        params = …
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
@@ -95,9 +95,17 @@ class InsanelyFastWhisperInference(WhisperBase):
             generate_kwargs=kwargs
         )
 
-        segments_result = …
-        …
-        …
+        segments_result = []
+        for item in segments["chunks"]:
+            start, end = item["timestamp"][0], item["timestamp"][1]
+            if end is None:
+                end = start
+            segments_result.append(Segment(
+                text=item["text"],
+                start=start,
+                end=end
+            ))
+
         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
modules/whisper/whisper_Inference.py
CHANGED
@@ -8,11 +8,11 @@ import os
 from argparse import Namespace
 
 from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
-from modules.whisper.…
-from modules.whisper.…
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
+from modules.whisper.data_classes import *
 
 
-class WhisperInference(WhisperBase):
+class WhisperInference(BaseTranscriptionPipeline):
     def __init__(self,
                  model_dir: str = WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -30,7 +30,7 @@ class WhisperInference(WhisperBase):
                    audio: Union[str, np.ndarray, torch.Tensor],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[…
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.
 
@@ -45,13 +45,13 @@ class WhisperInference(WhisperBase):
 
         Returns
         ----------
-        segments_result: List[…
-            list of …
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
         start_time = time.time()
-        params = …
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
@@ -59,21 +59,28 @@ class WhisperInference(WhisperBase):
         def progress_callback(progress_value):
             progress(progress_value, desc="Transcribing..")
 
-        …
-        …
+        result = self.model.transcribe(audio=audio,
+                                       language=params.lang,
+                                       verbose=False,
+                                       beam_size=params.beam_size,
+                                       logprob_threshold=params.log_prob_threshold,
+                                       no_speech_threshold=params.no_speech_threshold,
+                                       task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
+                                       fp16=True if params.compute_type == "float16" else False,
+                                       best_of=params.best_of,
+                                       patience=params.patience,
+                                       temperature=params.temperature,
+                                       compression_ratio_threshold=params.compression_ratio_threshold,
+                                       progress_callback=progress_callback,)["segments"]
+        segments_result = []
+        for segment in result:
+            segments_result.append(Segment(
+                start=segment["start"],
+                end=segment["end"],
+                text=segment["text"]
+            ))
 
+        elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
 
     def update_model(self,
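
All three implementations above now return a plain List[Segment]. As a small sketch of consuming that result (it assumes only the `start`, `end`, and `text` attributes that the `Segment(...)` constructor calls above already use; the function name is illustrative):

    def segments_to_text(segments) -> str:
        # Render transcribed segments as a readable, timestamped transcript.
        return "\n".join(f"[{seg.start:.2f} -> {seg.end:.2f}] {seg.text.strip()}" for seg in segments)
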
modules/whisper/whisper_factory.py
CHANGED
@@ -6,7 +6,8 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.whisper_Inference import WhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
-from modules.whisper.…
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
+from modules.whisper.data_classes import *
 
 
 class WhisperFactory:
@@ -19,7 +20,7 @@ class WhisperFactory:
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                  uvr_model_dir: str = UVR_MODELS_DIR,
                  output_dir: str = OUTPUT_DIR,
-                 ) -> "…
+                 ) -> "BaseTranscriptionPipeline":
         """
         Create a whisper inference class based on the provided whisper_type.
 
@@ -45,36 +46,29 @@ class WhisperFactory:
 
         Returns
         -------
-        …
+        BaseTranscriptionPipeline
            An instance of the appropriate whisper inference class based on the whisper_type.
         """
         # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
         os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
 
-        whisper_type = whisper_type.…
+        whisper_type = whisper_type.strip().lower()
 
-        …
-        whisper_typos = ["whisper"]
-        insanely_fast_whisper_typos = [
-            "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
-            "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
-        ]
-
-        if whisper_type in faster_whisper_typos:
+        if whisper_type == WhisperImpl.FASTER_WHISPER.value:
             return FasterWhisperInference(
                 model_dir=faster_whisper_model_dir,
                 output_dir=output_dir,
                 diarization_model_dir=diarization_model_dir,
                 uvr_model_dir=uvr_model_dir
             )
-        elif whisper_type …
+        elif whisper_type == WhisperImpl.WHISPER.value:
             return WhisperInference(
                 model_dir=whisper_model_dir,
                 output_dir=output_dir,
                 diarization_model_dir=diarization_model_dir,
                 uvr_model_dir=uvr_model_dir
             )
-        elif whisper_type …
+        elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value:
            return InsanelyFastWhisperInference(
                 model_dir=insanely_fast_whisper_model_dir,
                 output_dir=output_dir,
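
A rough usage sketch of the refactored factory. The method name `create_whisper_inference` and its being callable directly on the class are assumptions (the method's name is outside these hunks); only the normalization via `strip().lower()` and the comparisons against `WhisperImpl.*.value` are visible above:

    # Hypothetical call; pass the enum value so it matches after normalization.
    pipeline = WhisperFactory.create_whisper_inference(
        whisper_type=WhisperImpl.FASTER_WHISPER.value,  # e.g. taken from CLI args or config
        output_dir="outputs"                            # model directories keep their defaults
    )
    assert isinstance(pipeline, FasterWhisperInference)
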
modules/whisper/whisper_parameter.py
DELETED
@@ -1,371 +0,0 @@ (entire file removed)
from dataclasses import dataclass, fields
import gradio as gr
from typing import Optional, Dict
import yaml

from modules.utils.constants import AUTOMATIC_DETECTION


@dataclass
class WhisperParameters:
    model_size: gr.Dropdown
    lang: gr.Dropdown
    is_translate: gr.Checkbox
    beam_size: gr.Number
    log_prob_threshold: gr.Number
    no_speech_threshold: gr.Number
    compute_type: gr.Dropdown
    best_of: gr.Number
    patience: gr.Number
    condition_on_previous_text: gr.Checkbox
    prompt_reset_on_temperature: gr.Slider
    initial_prompt: gr.Textbox
    temperature: gr.Slider
    compression_ratio_threshold: gr.Number
    vad_filter: gr.Checkbox
    threshold: gr.Slider
    min_speech_duration_ms: gr.Number
    max_speech_duration_s: gr.Number
    min_silence_duration_ms: gr.Number
    speech_pad_ms: gr.Number
    batch_size: gr.Number
    is_diarize: gr.Checkbox
    hf_token: gr.Textbox
    diarization_device: gr.Dropdown
    length_penalty: gr.Number
    repetition_penalty: gr.Number
    no_repeat_ngram_size: gr.Number
    prefix: gr.Textbox
    suppress_blank: gr.Checkbox
    suppress_tokens: gr.Textbox
    max_initial_timestamp: gr.Number
    word_timestamps: gr.Checkbox
    prepend_punctuations: gr.Textbox
    append_punctuations: gr.Textbox
    max_new_tokens: gr.Number
    chunk_length: gr.Number
    hallucination_silence_threshold: gr.Number
    hotwords: gr.Textbox
    language_detection_threshold: gr.Number
    language_detection_segments: gr.Number
    is_bgm_separate: gr.Checkbox
    uvr_model_size: gr.Dropdown
    uvr_device: gr.Dropdown
    uvr_segment_size: gr.Number
    uvr_save_file: gr.Checkbox
    uvr_enable_offload: gr.Checkbox
    """
    A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
    This data class is used to mitigate the key-value problem between Gradio components and function parameters.
    Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
    See more about Gradio pre-processing: https://www.gradio.app/docs/components

    Attributes
    ----------
    model_size: gr.Dropdown
        Whisper model size.

    lang: gr.Dropdown
        Source language of the file to transcribe.

    is_translate: gr.Checkbox
        Boolean value that determines whether to translate to English.
        It's Whisper's feature to translate speech from another language directly into English end-to-end.

    beam_size: gr.Number
        Int value that is used for decoding option.

    log_prob_threshold: gr.Number
        If the average log probability over sampled tokens is below this value, treat as failed.

    no_speech_threshold: gr.Number
        If the no_speech probability is higher than this value AND
        the average log probability over sampled tokens is below `log_prob_threshold`,
        consider the segment as silent.

    compute_type: gr.Dropdown
        compute type for transcription.
        see more info : https://opennmt.net/CTranslate2/quantization.html

    best_of: gr.Number
        Number of candidates when sampling with non-zero temperature.

    patience: gr.Number
        Beam search patience factor.

    condition_on_previous_text: gr.Checkbox
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    initial_prompt: gr.Textbox
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those word correctly.

    temperature: gr.Slider
        Temperature for sampling. It can be a tuple of temperatures,
        which will be successively used upon failures according to either
        `compression_ratio_threshold` or `log_prob_threshold`.

    compression_ratio_threshold: gr.Number
        If the gzip compression ratio is above this value, treat as failed

    vad_filter: gr.Checkbox
        Enable the voice activity detection (VAD) to filter out parts of the audio
        without speech. This step is using the Silero VAD model
        https://github.com/snakers4/silero-vad.

    threshold: gr.Slider
        This parameter is related with Silero VAD. Speech threshold.
        Silero VAD outputs speech probabilities for each audio chunk,
        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

    min_speech_duration_ms: gr.Number
        This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.

    max_speech_duration_s: gr.Number
        This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: gr.Number
        This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
        before separating it

    speech_pad_ms: gr.Number
        This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side

    batch_size: gr.Number
        This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe

    is_diarize: gr.Checkbox
        This parameter is related with whisperx. Boolean value that determines whether to diarize or not.

    hf_token: gr.Textbox
        This parameter is related with whisperx. Huggingface token is needed to download diarization models.
        Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements

    diarization_device: gr.Dropdown
        This parameter is related with whisperx. Device to run diarization model

    length_penalty: gr.Number
        This parameter is related to faster-whisper. Exponential length penalty constant.

    repetition_penalty: gr.Number
        This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).

    no_repeat_ngram_size: gr.Number
        This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).

    prefix: gr.Textbox
        This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.

    suppress_blank: gr.Checkbox
        This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.

    suppress_tokens: gr.Textbox
        This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
        of symbols as defined in the model config.json file.

    max_initial_timestamp: gr.Number
        This parameter is related to faster-whisper. The initial timestamp cannot be later than this.

    word_timestamps: gr.Checkbox
        This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
        and dynamic time warping, and include the timestamps for each word in each segment.

    prepend_punctuations: gr.Textbox
        This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
        with the next word.

    append_punctuations: gr.Textbox
        This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
        with the previous word.

    max_new_tokens: gr.Number
        This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
        the maximum will be set by the default max_length.

    chunk_length: gr.Number
        This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
        If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.

    hallucination_silence_threshold: gr.Number
        This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
        (in seconds) when a possible hallucination is detected.

    hotwords: gr.Textbox
        This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.

    language_detection_threshold: gr.Number
        This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.

    language_detection_segments: gr.Number
        This parameter is related to faster-whisper. Number of segments to consider for the language detection.

    is_separate_bgm: gr.Checkbox
        This parameter is related to UVR. Boolean value that determines whether to separate bgm or not.

    uvr_model_size: gr.Dropdown
        This parameter is related to UVR. UVR model size.

    uvr_device: gr.Dropdown
        This parameter is related to UVR. Device to run UVR model.

    uvr_segment_size: gr.Number
        This parameter is related to UVR. Segment size for UVR model.

    uvr_save_file: gr.Checkbox
        This parameter is related to UVR. Boolean value that determines whether to save the file or not.

    uvr_enable_offload: gr.Checkbox
        This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
        after each transcription.
    """

    def as_list(self) -> list:
        """
        Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
        See more about Gradio pre-processing: : https://www.gradio.app/docs/components

        Returns
        ----------
        A list of Gradio components
        """
        return [getattr(self, f.name) for f in fields(self)]

    @staticmethod
    def as_value(*args) -> 'WhisperValues':
        """
        To use Whisper parameters in function after Gradio post-processing.
        See more about Gradio post-processing: : https://www.gradio.app/docs/components

        Returns
        ----------
        WhisperValues
           Data class that has values of parameters
        """
        return WhisperValues(*args)


@dataclass
class WhisperValues:
    model_size: str = "large-v2"
    lang: Optional[str] = None
    is_translate: bool = False
    beam_size: int = 5
    log_prob_threshold: float = -1.0
    no_speech_threshold: float = 0.6
    compute_type: str = "float16"
    best_of: int = 5
    patience: float = 1.0
    condition_on_previous_text: bool = True
    prompt_reset_on_temperature: float = 0.5
    initial_prompt: Optional[str] = None
    temperature: float = 0.0
    compression_ratio_threshold: float = 2.4
    vad_filter: bool = False
    threshold: float = 0.5
    min_speech_duration_ms: int = 250
    max_speech_duration_s: float = float("inf")
    min_silence_duration_ms: int = 2000
    speech_pad_ms: int = 400
    batch_size: int = 24
    is_diarize: bool = False
    hf_token: str = ""
    diarization_device: str = "cuda"
    length_penalty: float = 1.0
    repetition_penalty: float = 1.0
    no_repeat_ngram_size: int = 0
    prefix: Optional[str] = None
    suppress_blank: bool = True
    suppress_tokens: Optional[str] = "[-1]"
    max_initial_timestamp: float = 0.0
    word_timestamps: bool = False
    prepend_punctuations: Optional[str] = "\"'“¿([{-"
    append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
    max_new_tokens: Optional[int] = None
    chunk_length: Optional[int] = 30
    hallucination_silence_threshold: Optional[float] = None
    hotwords: Optional[str] = None
    language_detection_threshold: Optional[float] = None
    language_detection_segments: int = 1
    is_bgm_separate: bool = False
    uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
    uvr_device: str = "cuda"
    uvr_segment_size: int = 256
    uvr_save_file: bool = False
    uvr_enable_offload: bool = True
    """
    A data class to use Whisper parameters.
    """

    def to_yaml(self) -> Dict:
        data = {
            "whisper": {
                "model_size": self.model_size,
                "lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang,
                "is_translate": self.is_translate,
                "beam_size": self.beam_size,
                "log_prob_threshold": self.log_prob_threshold,
                "no_speech_threshold": self.no_speech_threshold,
                "best_of": self.best_of,
                "patience": self.patience,
                "condition_on_previous_text": self.condition_on_previous_text,
                "prompt_reset_on_temperature": self.prompt_reset_on_temperature,
                "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
                "temperature": self.temperature,
                "compression_ratio_threshold": self.compression_ratio_threshold,
                "batch_size": self.batch_size,
                "length_penalty": self.length_penalty,
                "repetition_penalty": self.repetition_penalty,
                "no_repeat_ngram_size": self.no_repeat_ngram_size,
                "prefix": None if not self.prefix else self.prefix,
                "suppress_blank": self.suppress_blank,
                "suppress_tokens": self.suppress_tokens,
                "max_initial_timestamp": self.max_initial_timestamp,
                "word_timestamps": self.word_timestamps,
                "prepend_punctuations": self.prepend_punctuations,
                "append_punctuations": self.append_punctuations,
                "max_new_tokens": self.max_new_tokens,
                "chunk_length": self.chunk_length,
                "hallucination_silence_threshold": self.hallucination_silence_threshold,
                "hotwords": None if not self.hotwords else self.hotwords,
                "language_detection_threshold": self.language_detection_threshold,
                "language_detection_segments": self.language_detection_segments,
            },
            "vad": {
                "vad_filter": self.vad_filter,
                "threshold": self.threshold,
                "min_speech_duration_ms": self.min_speech_duration_ms,
                "max_speech_duration_s": self.max_speech_duration_s,
                "min_silence_duration_ms": self.min_silence_duration_ms,
                "speech_pad_ms": self.speech_pad_ms,
            },
            "diarization": {
                "is_diarize": self.is_diarize,
                "hf_token": self.hf_token
            },
            "bgm_separation": {
                "is_separate_bgm": self.is_bgm_separate,
                "model_size": self.uvr_model_size,
                "segment_size": self.uvr_segment_size,
                "save_file": self.uvr_save_file,
                "enable_offload": self.uvr_enable_offload
            },
        }
        return data

    def as_list(self) -> list:
        """
        Converts the data class attributes into a list

        Returns
        ----------
        A list of Whisper parameters
        """
        return [getattr(self, f.name) for f in fields(self)]
tests/test_bgm_separation.py
CHANGED
@@ -1,6 +1,6 @@
 from modules.utils.paths import *
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.…
+from modules.whisper.data_classes import *
 from test_config import *
 from test_transcription import download_file, test_transcribe
 
@@ -17,9 +17,9 @@ import os
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        (…
-        (…
-        (…
+        (WhisperImpl.WHISPER.value, False, True, False),
+        (WhisperImpl.FASTER_WHISPER.value, False, True, False),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False)
     ]
 )
 def test_bgm_separation_pipeline(
@@ -38,9 +38,9 @@ def test_bgm_separation_pipeline(
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        (…
-        (…
-        (…
+        (WhisperImpl.WHISPER.value, True, True, False),
+        (WhisperImpl.FASTER_WHISPER.value, True, True, False),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False)
     ]
 )
 def test_bgm_separation_with_vad_pipeline(
tests/test_config.py
CHANGED
@@ -6,7 +6,7 @@ import torch
 TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
 TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
 TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
-TEST_WHISPER_MODEL = "tiny"
+TEST_WHISPER_MODEL = "tiny.en"
 TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
 TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
 TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
tests/test_diarization.py
CHANGED
@@ -1,6 +1,6 @@
 from modules.utils.paths import *
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.…
+from modules.whisper.data_classes import *
 from test_config import *
 from test_transcription import download_file, test_transcribe
 
@@ -16,9 +16,9 @@ import os
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        (…
-        (…
-        (…
+        (WhisperImpl.WHISPER.value, False, False, True),
+        (WhisperImpl.FASTER_WHISPER.value, False, False, True),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True)
     ]
 )
 def test_diarization_pipeline(
tests/test_transcription.py
CHANGED
@@ -1,5 +1,5 @@
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.…
+from modules.whisper.data_classes import *
 from modules.utils.paths import WEBUI_DIR
 from test_config import *
 
@@ -12,9 +12,9 @@ import os
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        (…
-        (…
-        (…
+        (WhisperImpl.WHISPER.value, False, False, False),
+        (WhisperImpl.FASTER_WHISPER.value, False, False, False),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False)
     ]
 )
 def test_transcribe(
@@ -37,14 +37,22 @@ def test_transcribe(
         f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
     )
 
-    hparams = …
-    …
-    …
+    hparams = TranscriptionPipelineParams(
+        whisper=WhisperParams(
+            model_size=TEST_WHISPER_MODEL,
+            compute_type=whisper_inferencer.current_compute_type
+        ),
+        vad=VadParams(
+            vad_filter=vad_filter
+        ),
+        bgm_separation=BGMSeparationParams(
+            is_separate_bgm=bgm_separation,
+            enable_offload=True
+        ),
+        diarization=DiarizationParams(
+            is_diarize=diarization
+        ),
+    ).to_list()
 
     subtitle_str, file_path = whisper_inferencer.transcribe_file(
         [audio_path],
tests/test_vad.py
CHANGED
@@ -1,6 +1,6 @@
 from modules.utils.paths import *
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.…
+from modules.whisper.data_classes import *
 from test_config import *
 from test_transcription import download_file, test_transcribe
 
@@ -12,9 +12,9 @@ import os
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        (…
-        (…
-        (…
+        (WhisperImpl.WHISPER.value, True, False, False),
+        (WhisperImpl.FASTER_WHISPER.value, True, False, False),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False)
    ]
 )
 def test_vad_pipeline(