jhj0517 committed on
Commit
7d9eec3
·
1 Parent(s): f96621b

Apply constants

Browse files
app.py CHANGED
@@ -3,6 +3,8 @@ import argparse
3
  import gradio as gr
4
  import yaml
5
 
 
 
6
  from modules.whisper.whisper_factory import WhisperFactory
7
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
8
  from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
@@ -33,9 +35,7 @@ class App:
33
  self.deepl_api = DeepLAPI(
34
  output_dir=os.path.join(self.args.output_dir, "translations")
35
  )
36
-
37
- default_param_path = os.path.join("configs", "default_parameters.yaml")
38
- with open(default_param_path, 'r', encoding='utf-8') as file:
39
  self.default_params = yaml.safe_load(file)
40
 
41
  def create_whisper_parameters(self):
@@ -290,7 +290,7 @@ class App:
290
  cb_deepl_ispro, cb_timestamp],
291
  outputs=[tb_indicator, files_subtitles])
292
 
293
- btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
294
  inputs=None,
295
  outputs=None)
296
 
@@ -321,7 +321,7 @@ class App:
321
  nb_max_length, cb_timestamp],
322
  outputs=[tb_indicator, files_subtitles])
323
 
324
- btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
325
  inputs=None,
326
  outputs=None)
327
 
@@ -369,18 +369,18 @@ parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme
369
  parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
370
  parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='Enable api or not in Gradio')
371
  parser.add_argument('--inbrowser', type=bool, default=True, nargs='?', const=True, help='Whether to automatically start Gradio app or not')
372
- parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"),
373
  help='Directory path of the whisper model')
374
- parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"),
375
  help='Directory path of the faster-whisper model')
376
  parser.add_argument('--insanely_fast_whisper_model_dir', type=str,
377
- default=os.path.join("models", "Whisper", "insanely-fast-whisper"),
378
  help='Directory path of the insanely-fast-whisper model')
379
- parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"),
380
  help='Directory path of the diarization model')
381
- parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"),
382
  help='Directory path of the Facebook NLLB model')
383
- parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
384
  _args = parser.parse_args()
385
 
386
  if __name__ == "__main__":
 
3
  import gradio as gr
4
  import yaml
5
 
6
+ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
7
+ INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH)
8
  from modules.whisper.whisper_factory import WhisperFactory
9
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
10
  from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 
35
  self.deepl_api = DeepLAPI(
36
  output_dir=os.path.join(self.args.output_dir, "translations")
37
  )
38
+ with open(DEFAULT_PARAMETERS_CONFIG_PATH, 'r', encoding='utf-8') as file:
 
 
39
  self.default_params = yaml.safe_load(file)
40
 
41
  def create_whisper_parameters(self):
 
290
  cb_deepl_ispro, cb_timestamp],
291
  outputs=[tb_indicator, files_subtitles])
292
 
293
+ btn_openfolder.click(fn=lambda: self.open_folder(os.path.join(self.args.output_dir, "translations")),
294
  inputs=None,
295
  outputs=None)
296
 
 
321
  nb_max_length, cb_timestamp],
322
  outputs=[tb_indicator, files_subtitles])
323
 
324
+ btn_openfolder.click(fn=lambda: self.open_folder(os.path.join(self.args.output_dir, "translations")),
325
  inputs=None,
326
  outputs=None)
327
 
 
369
  parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
370
  parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='Enable api or not in Gradio')
371
  parser.add_argument('--inbrowser', type=bool, default=True, nargs='?', const=True, help='Whether to automatically start Gradio app or not')
372
+ parser.add_argument('--whisper_model_dir', type=str, default=WHISPER_MODELS_DIR,
373
  help='Directory path of the whisper model')
374
+ parser.add_argument('--faster_whisper_model_dir', type=str, default=FASTER_WHISPER_MODELS_DIR,
375
  help='Directory path of the faster-whisper model')
376
  parser.add_argument('--insanely_fast_whisper_model_dir', type=str,
377
+ default=INSANELY_FAST_WHISPER_MODELS_DIR,
378
  help='Directory path of the insanely-fast-whisper model')
379
+ parser.add_argument('--diarization_model_dir', type=str, default=DIARIZATION_MODELS_DIR,
380
  help='Directory path of the diarization model')
381
+ parser.add_argument('--nllb_model_dir', type=str, default=NLLB_MODELS_DIR,
382
  help='Directory path of the Facebook NLLB model')
383
+ parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='Directory path of the outputs')
384
  _args = parser.parse_args()
385
 
386
  if __name__ == "__main__":
modules/diarize/diarize_pipeline.py CHANGED
@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
7
  from typing import Optional, Union
8
  import torch
9
 
 
10
  from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
11
 
12
 
@@ -14,7 +15,7 @@ class DiarizationPipeline:
14
  def __init__(
15
  self,
16
  model_name="pyannote/speaker-diarization-3.1",
17
- cache_dir: str = os.path.join("models", "Diarization"),
18
  use_auth_token=None,
19
  device: Optional[Union[str, torch.device]] = "cpu",
20
  ):
 
7
  from typing import Optional, Union
8
  import torch
9
 
10
+ from modules.utils.paths import DIARIZATION_MODELS_DIR
11
  from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
12
 
13
 
 
15
  def __init__(
16
  self,
17
  model_name="pyannote/speaker-diarization-3.1",
18
+ cache_dir: str = DIARIZATION_MODELS_DIR,
19
  use_auth_token=None,
20
  device: Optional[Union[str, torch.device]] = "cpu",
21
  ):
modules/diarize/diarizer.py CHANGED
@@ -5,13 +5,14 @@ import numpy as np
5
  import time
6
  import logging
7
 
 
8
  from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
9
  from modules.diarize.audio_loader import load_audio
10
 
11
 
12
  class Diarizer:
13
  def __init__(self,
14
- model_dir: str = os.path.join("models", "Diarization")
15
  ):
16
  self.device = self.get_device()
17
  self.available_device = self.get_available_device()
 
5
  import time
6
  import logging
7
 
8
+ from modules.utils.paths import DIARIZATION_MODELS_DIR
9
  from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
10
  from modules.diarize.audio_loader import load_audio
11
 
12
 
13
  class Diarizer:
14
  def __init__(self,
15
+ model_dir: str = DIARIZATION_MODELS_DIR
16
  ):
17
  self.device = self.get_device()
18
  self.available_device = self.get_available_device()
modules/translation/deepl_api.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  from datetime import datetime
5
  import gradio as gr
6
 
 
7
  from modules.utils.subtitle_manager import *
8
 
9
  """
@@ -83,7 +84,7 @@ DEEPL_AVAILABLE_SOURCE_LANGS = {
83
 
84
  class DeepLAPI:
85
  def __init__(self,
86
- output_dir: str = os.path.join("outputs", "translations")
87
  ):
88
  self.api_interval = 1
89
  self.max_text_batch_size = 50
 
4
  from datetime import datetime
5
  import gradio as gr
6
 
7
+ from modules.utils.paths import TRANSLATION_OUTPUT_DIR
8
  from modules.utils.subtitle_manager import *
9
 
10
  """
 
84
 
85
  class DeepLAPI:
86
  def __init__(self,
87
+ output_dir: str = TRANSLATION_OUTPUT_DIR
88
  ):
89
  self.api_interval = 1
90
  self.max_text_batch_size = 50
modules/translation/nllb_inference.py CHANGED
@@ -2,13 +2,14 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  import gradio as gr
3
  import os
4
 
 
5
  from modules.translation.translation_base import TranslationBase
6
 
7
 
8
  class NLLBInference(TranslationBase):
9
  def __init__(self,
10
- model_dir: str = os.path.join("models", "NLLB"),
11
- output_dir: str = os.path.join("outputs", "translations")
12
  ):
13
  super().__init__(
14
  model_dir=model_dir,
 
2
  import gradio as gr
3
  import os
4
 
5
+ from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
6
  from modules.translation.translation_base import TranslationBase
7
 
8
 
9
  class NLLBInference(TranslationBase):
10
  def __init__(self,
11
+ model_dir: str = NLLB_MODELS_DIR,
12
+ output_dir: str = TRANSLATION_OUTPUT_DIR
13
  ):
14
  super().__init__(
15
  model_dir=model_dir,
modules/utils/paths.py CHANGED
@@ -8,7 +8,7 @@ INSANELY_FAST_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "insanely-fa
8
  NLLB_MODELS_DIR = os.path.join(MODELS_DIR, "NLLB")
9
  DIARIZATION_MODELS_DIR = os.path.join(MODELS_DIR, "Diarization")
10
  CONFIGS_DIR = os.path.join(WEBUI_DIR, "configs")
11
- DEFAULT_PARAMETERS_PATH = os.path.join(CONFIGS_DIR, "default_parameters.yaml")
12
  OUTPUT_DIR = os.path.join(WEBUI_DIR, "outputs")
13
  TRANSLATION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "translations")
14
 
 
8
  NLLB_MODELS_DIR = os.path.join(MODELS_DIR, "NLLB")
9
  DIARIZATION_MODELS_DIR = os.path.join(MODELS_DIR, "Diarization")
10
  CONFIGS_DIR = os.path.join(WEBUI_DIR, "configs")
11
+ DEFAULT_PARAMETERS_CONFIG_PATH = os.path.join(CONFIGS_DIR, "default_parameters.yaml")
12
  OUTPUT_DIR = os.path.join(WEBUI_DIR, "outputs")
13
  TRANSLATION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "translations")
14
 
modules/whisper/faster_whisper_inference.py CHANGED
@@ -11,15 +11,16 @@ import whisper
11
  import gradio as gr
12
  from argparse import Namespace
13
 
 
14
  from modules.whisper.whisper_parameter import *
15
  from modules.whisper.whisper_base import WhisperBase
16
 
17
 
18
  class FasterWhisperInference(WhisperBase):
19
  def __init__(self,
20
- model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
21
- diarization_model_dir: str = os.path.join("models", "Diarization"),
22
- output_dir: str = os.path.join("outputs"),
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
@@ -163,14 +164,12 @@ class FasterWhisperInference(WhisperBase):
163
  wrong_dirs = [".locks"]
164
  existing_models = list(set(existing_models) - set(wrong_dirs))
165
 
166
- webui_dir = os.getcwd()
167
-
168
  for model_name in existing_models:
169
  if faster_whisper_prefix in model_name:
170
  model_name = model_name[len(faster_whisper_prefix):]
171
 
172
  if model_name not in whisper.available_models():
173
- model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
174
  return model_paths
175
 
176
  @staticmethod
 
11
  import gradio as gr
12
  from argparse import Namespace
13
 
14
+ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
15
  from modules.whisper.whisper_parameter import *
16
  from modules.whisper.whisper_base import WhisperBase
17
 
18
 
19
  class FasterWhisperInference(WhisperBase):
20
  def __init__(self,
21
+ model_dir: str = FASTER_WHISPER_MODELS_DIR,
22
+ diarization_model_dir: str = DIARIZATION_MODELS_DIR,
23
+ output_dir: str = OUTPUT_DIR,
24
  ):
25
  super().__init__(
26
  model_dir=model_dir,
 
164
  wrong_dirs = [".locks"]
165
  existing_models = list(set(existing_models) - set(wrong_dirs))
166
 
 
 
167
  for model_name in existing_models:
168
  if faster_whisper_prefix in model_name:
169
  model_name = model_name[len(faster_whisper_prefix):]
170
 
171
  if model_name not in whisper.available_models():
172
+ model_paths[model_name] = os.path.join(self.model_dir, model_name)
173
  return model_paths
174
 
175
  @staticmethod
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -11,15 +11,16 @@ import whisper
11
  from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
12
  from argparse import Namespace
13
 
 
14
  from modules.whisper.whisper_parameter import *
15
  from modules.whisper.whisper_base import WhisperBase
16
 
17
 
18
  class InsanelyFastWhisperInference(WhisperBase):
19
  def __init__(self,
20
- model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
21
- diarization_model_dir: str = os.path.join("models", "Diarization"),
22
- output_dir: str = os.path.join("outputs"),
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
 
11
  from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
12
  from argparse import Namespace
13
 
14
+ from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
15
  from modules.whisper.whisper_parameter import *
16
  from modules.whisper.whisper_base import WhisperBase
17
 
18
 
19
  class InsanelyFastWhisperInference(WhisperBase):
20
  def __init__(self,
21
+ model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
22
+ diarization_model_dir: str = DIARIZATION_MODELS_DIR,
23
+ output_dir: str = OUTPUT_DIR,
24
  ):
25
  super().__init__(
26
  model_dir=model_dir,
modules/whisper/whisper_Inference.py CHANGED
@@ -7,15 +7,16 @@ import torch
7
  import os
8
  from argparse import Namespace
9
 
 
10
  from modules.whisper.whisper_base import WhisperBase
11
  from modules.whisper.whisper_parameter import *
12
 
13
 
14
  class WhisperInference(WhisperBase):
15
  def __init__(self,
16
- model_dir: str = os.path.join("models", "Whisper"),
17
- diarization_model_dir: str = os.path.join("models", "Diarization"),
18
- output_dir: str = os.path.join("outputs"),
19
  ):
20
  super().__init__(
21
  model_dir=model_dir,
 
7
  import os
8
  from argparse import Namespace
9
 
10
+ from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
11
  from modules.whisper.whisper_base import WhisperBase
12
  from modules.whisper.whisper_parameter import *
13
 
14
 
15
  class WhisperInference(WhisperBase):
16
  def __init__(self,
17
+ model_dir: str = WHISPER_MODELS_DIR,
18
+ diarization_model_dir: str = DIARIZATION_MODELS_DIR,
19
+ output_dir: str = OUTPUT_DIR,
20
  ):
21
  super().__init__(
22
  model_dir=model_dir,
modules/whisper/whisper_base.py CHANGED
@@ -9,6 +9,7 @@ from datetime import datetime
9
  from faster_whisper.vad import VadOptions
10
  from dataclasses import astuple
11
 
 
12
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
13
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
14
  from modules.utils.files_manager import get_media_files, format_gradio_files
@@ -19,9 +20,9 @@ from modules.vad.silero_vad import SileroVAD
19
 
20
  class WhisperBase(ABC):
21
  def __init__(self,
22
- model_dir: str = os.path.join("models", "Whisper"),
23
- diarization_model_dir: str = os.path.join("models", "Diarization"),
24
- output_dir: str = os.path.join("outputs"),
25
  ):
26
  self.model_dir = model_dir
27
  self.output_dir = output_dir
 
9
  from faster_whisper.vad import VadOptions
10
  from dataclasses import astuple
11
 
12
+ from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
13
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
14
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
15
  from modules.utils.files_manager import get_media_files, format_gradio_files
 
20
 
21
  class WhisperBase(ABC):
22
  def __init__(self,
23
+ model_dir: str = WHISPER_MODELS_DIR,
24
+ diarization_model_dir: str = DIARIZATION_MODELS_DIR,
25
+ output_dir: str = OUTPUT_DIR,
26
  ):
27
  self.model_dir = model_dir
28
  self.output_dir = output_dir
modules/whisper/whisper_factory.py CHANGED
@@ -1,6 +1,8 @@
1
  from typing import Optional
2
  import os
3
 
 
 
4
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
5
  from modules.whisper.whisper_Inference import WhisperInference
6
  from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
@@ -11,11 +13,11 @@ class WhisperFactory:
11
  @staticmethod
12
  def create_whisper_inference(
13
  whisper_type: str,
14
- whisper_model_dir: str = os.path.join("models", "Whisper"),
15
- faster_whisper_model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
16
- insanely_fast_whisper_model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
17
- diarization_model_dir: str = os.path.join("models", "Diarization"),
18
- output_dir: str = os.path.join("outputs"),
19
  ) -> "WhisperBase":
20
  """
21
  Create a whisper inference class based on the provided whisper_type.
 
1
  from typing import Optional
2
  import os
3
 
4
+ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR,
5
+ INSANELY_FAST_WHISPER_MODELS_DIR, WHISPER_MODELS_DIR)
6
  from modules.whisper.faster_whisper_inference import FasterWhisperInference
7
  from modules.whisper.whisper_Inference import WhisperInference
8
  from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 
13
  @staticmethod
14
  def create_whisper_inference(
15
  whisper_type: str,
16
+ whisper_model_dir: str = WHISPER_MODELS_DIR,
17
+ faster_whisper_model_dir: str = FASTER_WHISPER_MODELS_DIR,
18
+ insanely_fast_whisper_model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
19
+ diarization_model_dir: str = DIARIZATION_MODELS_DIR,
20
+ output_dir: str = OUTPUT_DIR,
21
  ) -> "WhisperBase":
22
  """
23
  Create a whisper inference class based on the provided whisper_type.