jhj0517 committed on
Commit 9e5ed74 · 1 Parent(s): be000f4

Enable fine-tuned models

modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -32,9 +32,7 @@ class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
         self.model_dir = model_dir
         os.makedirs(self.model_dir, exist_ok=True)
 
-        openai_models = whisper.available_models()
-        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
-        self.available_models = openai_models + distil_models
+        self.available_models = self.get_model_paths()
 
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
@@ -146,31 +144,26 @@ class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
             model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
         )
 
-    @staticmethod
-    def format_result(
-            transcribed_result: dict
-    ) -> List[dict]:
+    def get_model_paths(self):
         """
-        Format the transcription result of insanely_fast_whisper as the same with other implementation.
-
-        Parameters
-        ----------
-        transcribed_result: dict
-            Transcription result of the insanely_fast_whisper
+        Get available models from models path including fine-tuned model.
 
         Returns
         ----------
-        result: List[dict]
-            Formatted result as the same with other implementation
+        Name set of models
         """
-        result = transcribed_result["chunks"]
-        for item in result:
-            start, end = item["timestamp"][0], item["timestamp"][1]
-            if end is None:
-                end = start
-            item["start"] = start
-            item["end"] = end
-        return result
+        openai_models = whisper.available_models()
+        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
+        default_models = openai_models + distil_models
+
+        existing_models = os.listdir(self.model_dir)
+        wrong_dirs = [".locks"]
+
+        available_models = default_models + existing_models
+        available_models = [model for model in available_models if model not in wrong_dirs]
+        available_models = sorted(set(available_models), key=available_models.index)
+
+        return available_models
 
     @staticmethod
     def download_model(
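For readers skimming the change: the new get_model_paths() merges the built-in Whisper and Distil-Whisper names with whatever model folders already exist under model_dir, filters out housekeeping directories such as .locks, and de-duplicates while preserving order via sorted(set(x), key=x.index). Below is a minimal standalone sketch of the same idea; the function name list_available_models and the shortened default list are illustrative stand-ins, not part of the commit.

import os


def list_available_models(model_dir: str) -> list:
    """Combine built-in model names with fine-tuned models found on disk.

    Illustrative sketch of the approach in get_model_paths(); the default
    names below are a shortened stand-in for whisper.available_models().
    """
    default_models = ["tiny", "base", "small", "medium", "large-v2",
                      "large-v3", "distil-large-v3"]

    # Any folder dropped into model_dir is treated as a (fine-tuned) model,
    # except housekeeping directories such as Hugging Face's ".locks".
    wrong_dirs = {".locks"}
    existing_models = []
    if os.path.isdir(model_dir):
        existing_models = [name for name in os.listdir(model_dir)
                           if name not in wrong_dirs]

    # Concatenate, then de-duplicate while keeping the original order:
    # sorted(set(x), key=x.index) orders unique items by first occurrence.
    combined = default_models + existing_models
    return sorted(set(combined), key=combined.index)

With this scheme, placing a converted fine-tuned checkpoint folder inside model_dir is enough for it to appear in the returned list alongside the default model names.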