Spaces:
Running
Running
jhj0517
commited on
Commit
·
89f23c0
1
Parent(s):
2cbdd50
Fix wrong caching parameters and resolve circular imports
Browse files
modules/translation/nllb_inference.py
CHANGED
@@ -3,10 +3,10 @@ import gradio as gr
|
|
3 |
import os
|
4 |
|
5 |
from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
|
6 |
-
|
7 |
|
8 |
|
9 |
-
class NLLBInference(TranslationBase):
|
10 |
def __init__(self,
|
11 |
model_dir: str = NLLB_MODELS_DIR,
|
12 |
output_dir: str = TRANSLATION_OUTPUT_DIR
|
|
|
3 |
import os
|
4 |
|
5 |
from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
|
6 |
+
import modules.translation.translation_base as base
|
7 |
|
8 |
|
9 |
+
class NLLBInference(base.TranslationBase):
|
10 |
def __init__(self,
|
11 |
model_dir: str = NLLB_MODELS_DIR,
|
12 |
output_dir: str = TRANSLATION_OUTPUT_DIR
|
modules/translation/translation_base.py
CHANGED
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
|
5 |
from typing import List
|
6 |
from datetime import datetime
|
7 |
|
|
|
8 |
from modules.whisper.whisper_parameter import *
|
9 |
from modules.utils.subtitle_manager import *
|
10 |
from modules.utils.files_manager import load_yaml, save_yaml
|
@@ -166,11 +167,17 @@ class TranslationBase(ABC):
|
|
166 |
tgt_lang: str,
|
167 |
max_length: int,
|
168 |
add_timestamp: bool):
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
170 |
cached_params["translation"]["nllb"] = {
|
171 |
"model_size": model_size,
|
172 |
-
"source_lang": src_lang,
|
173 |
-
"target_lang": tgt_lang,
|
174 |
"max_length": max_length,
|
175 |
}
|
176 |
cached_params["translation"]["add_timestamp"] = add_timestamp
|
|
|
5 |
from typing import List
|
6 |
from datetime import datetime
|
7 |
|
8 |
+
import modules.translation.nllb_inference as nllb
|
9 |
from modules.whisper.whisper_parameter import *
|
10 |
from modules.utils.subtitle_manager import *
|
11 |
from modules.utils.files_manager import load_yaml, save_yaml
|
|
|
167 |
tgt_lang: str,
|
168 |
max_length: int,
|
169 |
add_timestamp: bool):
|
170 |
+
def validate_lang(lang: str):
|
171 |
+
if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
|
172 |
+
flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
|
173 |
+
return flipped[lang]
|
174 |
+
return lang
|
175 |
+
|
176 |
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
177 |
cached_params["translation"]["nllb"] = {
|
178 |
"model_size": model_size,
|
179 |
+
"source_lang": validate_lang(src_lang),
|
180 |
+
"target_lang": validate_lang(tgt_lang),
|
181 |
"max_length": max_length,
|
182 |
}
|
183 |
cached_params["translation"]["add_timestamp"] = add_timestamp
|