jhj0517 commited on
Commit
89f23c0
·
1 Parent(s): 2cbdd50

Fix wrong caching parameters and resolve circular imports

Browse files
modules/translation/nllb_inference.py CHANGED
@@ -3,10 +3,10 @@ import gradio as gr
3
  import os
4
 
5
  from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
6
- from modules.translation.translation_base import TranslationBase
7
 
8
 
9
- class NLLBInference(TranslationBase):
10
  def __init__(self,
11
  model_dir: str = NLLB_MODELS_DIR,
12
  output_dir: str = TRANSLATION_OUTPUT_DIR
 
3
  import os
4
 
5
  from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
6
+ import modules.translation.translation_base as base
7
 
8
 
9
+ class NLLBInference(base.TranslationBase):
10
  def __init__(self,
11
  model_dir: str = NLLB_MODELS_DIR,
12
  output_dir: str = TRANSLATION_OUTPUT_DIR
modules/translation/translation_base.py CHANGED
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
5
  from typing import List
6
  from datetime import datetime
7
 
 
8
  from modules.whisper.whisper_parameter import *
9
  from modules.utils.subtitle_manager import *
10
  from modules.utils.files_manager import load_yaml, save_yaml
@@ -166,11 +167,17 @@ class TranslationBase(ABC):
166
  tgt_lang: str,
167
  max_length: int,
168
  add_timestamp: bool):
 
 
 
 
 
 
169
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
170
  cached_params["translation"]["nllb"] = {
171
  "model_size": model_size,
172
- "source_lang": src_lang,
173
- "target_lang": tgt_lang,
174
  "max_length": max_length,
175
  }
176
  cached_params["translation"]["add_timestamp"] = add_timestamp
 
5
  from typing import List
6
  from datetime import datetime
7
 
8
+ import modules.translation.nllb_inference as nllb
9
  from modules.whisper.whisper_parameter import *
10
  from modules.utils.subtitle_manager import *
11
  from modules.utils.files_manager import load_yaml, save_yaml
 
167
  tgt_lang: str,
168
  max_length: int,
169
  add_timestamp: bool):
170
+ def validate_lang(lang: str):
171
+ if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
172
+ flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
173
+ return flipped[lang]
174
+ return lang
175
+
176
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
177
  cached_params["translation"]["nllb"] = {
178
  "model_size": model_size,
179
+ "source_lang": validate_lang(src_lang),
180
+ "target_lang": validate_lang(tgt_lang),
181
  "max_length": max_length,
182
  }
183
  cached_params["translation"]["add_timestamp"] = add_timestamp