jhj0517 commited on
Commit
19c3dbd
·
1 Parent(s): abb1ca2

Apply default values with yaml

Browse files
Files changed (1) hide show
  1. app.py +56 -38
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import argparse
3
  import gradio as gr
 
4
 
5
  from modules.whisper.whisper_factory import WhisperFactory
6
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
@@ -33,102 +34,119 @@ class App:
33
  output_dir=os.path.join(self.args.output_dir, "translations")
34
  )
35
 
 
 
 
 
36
  def create_whisper_parameters(self):
 
 
 
 
37
  with gr.Row():
38
- dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
39
  label="Model")
40
  dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
41
- value="Automatic Detection", label="Language")
42
  dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
43
  with gr.Row():
44
- cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
 
45
  with gr.Row():
46
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
47
  interactive=True)
48
  with gr.Accordion("Advanced Parameters", open=False):
49
- nb_beam_size = gr.Number(label="Beam Size", value=5, precision=0, interactive=True,
50
  info="Beam size to use for decoding.")
51
- nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True,
52
  info="If the average log probability over sampled tokens is below this value, treat as failed.")
53
- nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True,
54
  info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
55
  dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
56
  value=self.whisper_inf.current_compute_type, interactive=True,
57
  info="Select the type of computation to perform.")
58
- nb_best_of = gr.Number(label="Best Of", value=5, interactive=True,
59
  info="Number of candidates when sampling with non-zero temperature.")
60
- nb_patience = gr.Number(label="Patience", value=1, interactive=True,
61
  info="Beam search patience factor.")
62
- cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True,
63
  interactive=True,
64
  info="Condition on previous text during decoding.")
65
- sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=0.5,
66
  minimum=0, maximum=1, step=0.01, interactive=True,
67
  info="Resets prompt if temperature is above this value."
68
  " Arg has effect only if 'Condition On Previous Text' is True.")
69
  tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
70
  info="Initial prompt to use for decoding.")
71
- sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True,
 
72
  info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
73
- nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True,
 
74
  info="If the gzip compression ratio is above this value, treat as failed.")
75
  with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
76
- nb_length_penalty = gr.Number(label="Length Penalty", value=1,
77
  info="Exponential length penalty constant.")
78
- nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=1,
79
  info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
80
- nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=0, precision=0,
 
81
  info="Prevent repetitions of n-grams with this size (set 0 to disable).")
82
- tb_prefix = gr.Textbox(label="Prefix", value=lambda: None,
83
  info="Optional text to provide as a prefix for the first window.")
84
- cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=True,
85
  info="Suppress blank outputs at the beginning of the sampling.")
86
- tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value="[-1]",
87
  info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
88
- nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=1.0,
89
  info="The initial timestamp cannot be later than this.")
90
- cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=False,
91
  info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
92
- tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value="\"'“¿([{-",
93
  info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
94
- tb_append_punctuations = gr.Textbox(label="Append Punctuations", value="\"'.。,,!!??::”)]}、",
95
  info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
96
- nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: None, precision=0,
 
97
  info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
98
- nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: None, precision=0,
 
99
  info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
100
  nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
101
- value=lambda: None,
102
  info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
103
- tb_hotwords = gr.Textbox(label="Hotwords", value=None,
104
  info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
105
- nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=None,
106
  info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
107
- nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=1, precision=0,
 
108
  info="Number of segments to consider for the language detection.")
109
  with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
110
- nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
111
- nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
 
112
 
113
  with gr.Accordion("VAD", open=False):
114
- cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
115
- sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5,
 
116
  info="Lower it to be more sensitive to small sounds.")
117
- nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250,
118
  info="Final speech chunks shorter than this time are thrown out")
119
- nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999,
120
  info="Maximum duration of speech chunks in \"seconds\". Chunks longer"
121
  " than this time will be split at the timestamp of the last silence that"
122
  " lasts more than 100ms (if any), to prevent aggressive cutting.")
123
- nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000,
124
  info="In the end of each speech chunk wait for this time"
125
  " before separating it")
126
- nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400,
127
  info="Final speech chunks are padded by this time each side")
128
 
129
  with gr.Accordion("Diarization", open=False):
130
- cb_diarize = gr.Checkbox(label="Enable Diarization")
131
- tb_hf_token = gr.Text(label="HuggingFace Token", value="",
132
  info="This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
133
  dd_diarization_device = gr.Dropdown(label="Device",
134
  choices=self.whisper_inf.diarizer.get_available_device(),
 
1
  import os
2
  import argparse
3
  import gradio as gr
4
+ import yaml
5
 
6
  from modules.whisper.whisper_factory import WhisperFactory
7
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
 
34
  output_dir=os.path.join(self.args.output_dir, "translations")
35
  )
36
 
37
+ default_param_path = os.path.join("configs", "default_parameters.yaml")
38
+ with open(default_param_path, 'r', encoding='utf-8') as file:
39
+ self.default_params = yaml.safe_load(file)
40
+
41
  def create_whisper_parameters(self):
42
+ whisper_params = self.default_params["whisper"]
43
+ vad_params = self.default_params["vad"]
44
+ diarization_params = self.default_params["diarization"]
45
+
46
  with gr.Row():
47
+ dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value=whisper_params["model_size"],
48
  label="Model")
49
  dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
50
+ value=whisper_params["lang"], label="Language")
51
  dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
52
  with gr.Row():
53
+ cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label="Translate to English?",
54
+ interactive=True)
55
  with gr.Row():
56
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
57
  interactive=True)
58
  with gr.Accordion("Advanced Parameters", open=False):
59
+ nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0, interactive=True,
60
  info="Beam size to use for decoding.")
61
+ nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=whisper_params["log_prob_threshold"], interactive=True,
62
  info="If the average log probability over sampled tokens is below this value, treat as failed.")
63
+ nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"], interactive=True,
64
  info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
65
  dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
66
  value=self.whisper_inf.current_compute_type, interactive=True,
67
  info="Select the type of computation to perform.")
68
+ nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
69
  info="Number of candidates when sampling with non-zero temperature.")
70
+ nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
71
  info="Beam search patience factor.")
72
+ cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=whisper_params["condition_on_previous_text"],
73
  interactive=True,
74
  info="Condition on previous text during decoding.")
75
+ sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=whisper_params["prompt_reset_on_temperature"],
76
  minimum=0, maximum=1, step=0.01, interactive=True,
77
  info="Resets prompt if temperature is above this value."
78
  " Arg has effect only if 'Condition On Previous Text' is True.")
79
  tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
80
  info="Initial prompt to use for decoding.")
81
+ sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
82
+ step=0.01, maximum=1.0, interactive=True,
83
  info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
84
+ nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=whisper_params["compression_ratio_threshold"],
85
+ interactive=True,
86
  info="If the gzip compression ratio is above this value, treat as failed.")
87
  with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
88
+ nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
89
  info="Exponential length penalty constant.")
90
+ nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=whisper_params["repetition_penalty"],
91
  info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
92
+ nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=whisper_params["no_repeat_ngram_size"],
93
+ precision=0,
94
  info="Prevent repetitions of n-grams with this size (set 0 to disable).")
95
+ tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
96
  info="Optional text to provide as a prefix for the first window.")
97
+ cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
98
  info="Suppress blank outputs at the beginning of the sampling.")
99
+ tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
100
  info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
101
+ nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=whisper_params["max_initial_timestamp"],
102
  info="The initial timestamp cannot be later than this.")
103
+ cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
104
  info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
105
+ tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value=whisper_params["prepend_punctuations"],
106
  info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
107
+ tb_append_punctuations = gr.Textbox(label="Append Punctuations", value=whisper_params["append_punctuations"],
108
  info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
109
+ nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
110
+ precision=0,
111
  info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
112
+ nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: whisper_params["chunk_length"],
113
+ precision=0,
114
  info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
115
  nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
116
+ value=lambda: whisper_params["hallucination_silence_threshold"],
117
  info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
118
+ tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
119
  info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
120
+ nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=lambda: whisper_params["language_detection_threshold"],
121
  info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
122
+ nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=lambda: whisper_params["language_detection_segments"],
123
+ precision=0,
124
  info="Number of segments to consider for the language detection.")
125
  with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
126
+ nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=whisper_params["chunk_length_s"],
127
+ precision=0)
128
+ nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
129
 
130
  with gr.Accordion("VAD", open=False):
131
+ cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=vad_params["vad_filter"],
132
+ interactive=True)
133
+ sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=vad_params["threshold"],
134
  info="Lower it to be more sensitive to small sounds.")
135
+ nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=vad_params["min_speech_duration_ms"],
136
  info="Final speech chunks shorter than this time are thrown out")
137
+ nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=vad_params["max_speech_duration_s"],
138
  info="Maximum duration of speech chunks in \"seconds\". Chunks longer"
139
  " than this time will be split at the timestamp of the last silence that"
140
  " lasts more than 100ms (if any), to prevent aggressive cutting.")
141
+ nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=vad_params["min_silence_duration_ms"],
142
  info="In the end of each speech chunk wait for this time"
143
  " before separating it")
144
+ nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
145
  info="Final speech chunks are padded by this time each side")
146
 
147
  with gr.Accordion("Diarization", open=False):
148
+ cb_diarize = gr.Checkbox(label="Enable Diarization", value=diarization_params["is_diarize"])
149
+ tb_hf_token = gr.Text(label="HuggingFace Token", value=diarization_params["hf_token"],
150
  info="This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
151
  dd_diarization_device = gr.Dropdown(label="Device",
152
  choices=self.whisper_inf.diarizer.get_available_device(),