jhj0517 commited on
Commit
4c4ecfe
·
1 Parent(s): 2a2f7c6

Update to use enum

Browse files
tests/test_bgm_separation.py CHANGED
@@ -1,6 +1,6 @@
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
- from modules.whisper.data_classes import TranscriptionPipelineParams
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
@@ -17,9 +17,9 @@ import os
17
  @pytest.mark.parametrize(
18
  "whisper_type,vad_filter,bgm_separation,diarization",
19
  [
20
- ("whisper", False, True, False),
21
- ("faster-whisper", False, True, False),
22
- ("insanely_fast_whisper", False, True, False)
23
  ]
24
  )
25
  def test_bgm_separation_pipeline(
@@ -38,9 +38,9 @@ def test_bgm_separation_pipeline(
38
  @pytest.mark.parametrize(
39
  "whisper_type,vad_filter,bgm_separation,diarization",
40
  [
41
- ("whisper", True, True, False),
42
- ("faster-whisper", True, True, False),
43
- ("insanely_fast_whisper", True, True, False)
44
  ]
45
  )
46
  def test_bgm_separation_with_vad_pipeline(
 
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
+ from modules.whisper.data_classes import *
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
 
17
  @pytest.mark.parametrize(
18
  "whisper_type,vad_filter,bgm_separation,diarization",
19
  [
20
+ (WhisperImpl.WHISPER.value, False, True, False),
21
+ (WhisperImpl.FASTER_WHISPER.value, False, True, False),
22
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False)
23
  ]
24
  )
25
  def test_bgm_separation_pipeline(
 
38
  @pytest.mark.parametrize(
39
  "whisper_type,vad_filter,bgm_separation,diarization",
40
  [
41
+ (WhisperImpl.WHISPER.value, True, True, False),
42
+ (WhisperImpl.FASTER_WHISPER.value, True, True, False),
43
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False)
44
  ]
45
  )
46
  def test_bgm_separation_with_vad_pipeline(
tests/test_diarization.py CHANGED
@@ -1,6 +1,6 @@
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
- from modules.whisper.data_classes import TranscriptionPipelineParams
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
@@ -16,9 +16,9 @@ import os
16
  @pytest.mark.parametrize(
17
  "whisper_type,vad_filter,bgm_separation,diarization",
18
  [
19
- ("whisper", False, False, True),
20
- ("faster-whisper", False, False, True),
21
- ("insanely_fast_whisper", False, False, True)
22
  ]
23
  )
24
  def test_diarization_pipeline(
 
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
+ from modules.whisper.data_classes import *
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
 
16
  @pytest.mark.parametrize(
17
  "whisper_type,vad_filter,bgm_separation,diarization",
18
  [
19
+ (WhisperImpl.WHISPER.value, False, False, True),
20
+ (WhisperImpl.FASTER_WHISPER.value, False, False, True),
21
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True)
22
  ]
23
  )
24
  def test_diarization_pipeline(
tests/test_transcription.py CHANGED
@@ -12,9 +12,9 @@ import os
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
15
- ("whisper", False, False, False),
16
- ("faster-whisper", False, False, False),
17
- ("insanely_fast_whisper", False, False, False)
18
  ]
19
  )
20
  def test_transcribe(
 
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
15
+ (WhisperImpl.WHISPER.value, False, False, False),
16
+ (WhisperImpl.FASTER_WHISPER.value, False, False, False),
17
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False)
18
  ]
19
  )
20
  def test_transcribe(
tests/test_vad.py CHANGED
@@ -1,6 +1,6 @@
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
- from modules.whisper.data_classes import TranscriptionPipelineParams
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
@@ -12,9 +12,9 @@ import os
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
15
- ("whisper", True, False, False),
16
- ("faster-whisper", True, False, False),
17
- ("insanely_fast_whisper", True, False, False)
18
  ]
19
  )
20
  def test_vad_pipeline(
 
1
  from modules.utils.paths import *
2
  from modules.whisper.whisper_factory import WhisperFactory
3
+ from modules.whisper.data_classes import *
4
  from test_config import *
5
  from test_transcription import download_file, test_transcribe
6
 
 
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
15
+ (WhisperImpl.WHISPER.value, True, False, False),
16
+ (WhisperImpl.FASTER_WHISPER.value, True, False, False),
17
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False)
18
  ]
19
  )
20
  def test_vad_pipeline(