import os
import time
import ast
from typing import BinaryIO, Union, Tuple, List

import numpy as np
import torch
import faster_whisper
import whisper
import gradio as gr

from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR,
                                 OUTPUT_DIR)
from modules.whisper.data_classes import *
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline


class FasterWhisperInference(BaseTranscriptionPipeline):
    def __init__(self,
                 model_dir: str = FASTER_WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        super().__init__(
            model_dir=model_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir,
            output_dir=output_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.model_paths = self.get_model_paths()
        self.device = self.get_device()
        self.available_models = self.model_paths.keys()

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[Segment], float]:
        """
        Transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path, file binary, or audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to whisper. Parsed with the "WhisperParams" data class.

        Returns
        ----------
        segments_result: List[Segment]
            list of Segment that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()

        params = WhisperParams.from_list(list(whisper_params))

        # Reload the model only when the requested size or compute type has changed.
        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        progress(0, desc="Loading audio..")
        segments, info = self.model.transcribe(
            audio=audio,
            language=params.lang,
            task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
            beam_size=params.beam_size,
            log_prob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            initial_prompt=params.initial_prompt,
            compression_ratio_threshold=params.compression_ratio_threshold,
            length_penalty=params.length_penalty,
            repetition_penalty=params.repetition_penalty,
            no_repeat_ngram_size=params.no_repeat_ngram_size,
            prefix=params.prefix,
            suppress_blank=params.suppress_blank,
            suppress_tokens=params.suppress_tokens,
            max_initial_timestamp=params.max_initial_timestamp,
            word_timestamps=params.word_timestamps,
            prepend_punctuations=params.prepend_punctuations,
            append_punctuations=params.append_punctuations,
            max_new_tokens=params.max_new_tokens,
            chunk_length=params.chunk_length,
            hallucination_silence_threshold=params.hallucination_silence_threshold,
            hotwords=params.hotwords,
            language_detection_threshold=params.language_detection_threshold,
            language_detection_segments=params.language_detection_segments,
            prompt_reset_on_temperature=params.prompt_reset_on_temperature,
        )

        # `segments` is a lazy generator; iterating it drives the actual decoding.
        segments_result = []
        for segment in segments:
            progress(segment.start / info.duration, desc="Transcribing..")
            segments_result.append(Segment(
                start=segment.start,
                end=segment.end,
                text=segment.text
            ))

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            See more info: https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        # Store the model *name* so the reload check in transcribe() also works for
        # fine-tuned models, whose local paths differ from their display names.
        self.current_model_size = model_size
        self.current_compute_type = compute_type
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=self.model_paths[model_size],
            download_root=self.model_dir,
            compute_type=self.current_compute_type
        )

    def get_model_paths(self):
        """
        Get available models from the models directory, including fine-tuned models.

        Returns
        ----------
        Dict mapping model names to their local paths (or to the model name itself
        for downloadable stock models)
        """
        model_paths = {model: model for model in faster_whisper.available_models()}
        faster_whisper_prefix = "models--Systran--faster-whisper-"

        existing_models = os.listdir(self.model_dir)
        wrong_dirs = [".locks"]
        existing_models = list(set(existing_models) - set(wrong_dirs))

        for model_name in existing_models:
            if faster_whisper_prefix in model_name:
                model_name = model_name[len(faster_whisper_prefix):]

            if model_name not in whisper.available_models():
                model_paths[model_name] = os.path.join(self.model_dir, model_name)
        return model_paths

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        else:
            # "auto" lets ctranslate2 choose the best available device (effectively CPU here).
            return "auto"

    @staticmethod
    def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
        try:
            suppress_tokens = ast.literal_eval(suppress_tokens_str)
            if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
                raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
            return suppress_tokens
        except Exception as e:
            raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") from e
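
# A minimal usage sketch, not part of the module itself. It assumes the default
# model directories exist, that "audio.wav" is a real file, and that the
# WhisperParams data class has sensible field defaults and a to_list() counterpart
# to the from_list() used above — all of which are assumptions for illustration.
if __name__ == "__main__":
    pipeline = FasterWhisperInference()
    params = WhisperParams(model_size="base")  # assumed: other fields default
    segments, elapsed_time = pipeline.transcribe("audio.wav", gr.Progress(), *params.to_list())
    for seg in segments:
        print(f"[{seg.start:7.2f} -> {seg.end:7.2f}] {seg.text}")
    print(f"Transcribed in {elapsed_time:.1f}s")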