import gradio as gr import torch from typing import Optional, Dict, List, Union from pydantic import BaseModel, Field, field_validator, ConfigDict from gradio_i18n import Translate, gettext as _ from enum import Enum from copy import deepcopy import yaml from modules.utils.constants import * class WhisperImpl(Enum): WHISPER = "whisper" FASTER_WHISPER = "faster-whisper" INSANELY_FAST_WHISPER = "insanely_fast_whisper" class Segment(BaseModel): text: Optional[str] = Field(default=None, description="Transcription text of the segment") start: Optional[float] = Field(default=None, description="Start time of the segment") end: Optional[float] = Field(default=None, description="End time of the segment") class BaseParams(BaseModel): model_config = ConfigDict(protected_namespaces=()) def to_dict(self) -> Dict: return self.model_dump() def to_list(self) -> List: return list(self.model_dump().values()) @classmethod def from_list(cls, data_list: List) -> 'BaseParams': field_names = list(cls.model_fields.keys()) return cls(**dict(zip(field_names, data_list))) class VadParams(BaseParams): """Voice Activity Detection parameters""" vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts") threshold: float = Field( default=0.5, ge=0.0, le=1.0, description="Speech threshold for Silero VAD. Probabilities above this value are considered speech" ) min_speech_duration_ms: int = Field( default=250, ge=0, description="Final speech chunks shorter than this are discarded" ) max_speech_duration_s: float = Field( default=float("inf"), gt=0, description="Maximum duration of speech chunks in seconds" ) min_silence_duration_ms: int = Field( default=2000, ge=0, description="Minimum silence duration between speech chunks" ) speech_pad_ms: int = Field( default=400, ge=0, description="Padding added to each side of speech chunks" ) @classmethod def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]: return [ gr.Checkbox( label=_("Enable Silero VAD Filter"), value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default), interactive=True, info=_("Enable this to transcribe only detected voice") ), gr.Slider( minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=defaults.get("threshold", cls.__fields__["threshold"].default), info="Lower it to be more sensitive to small sounds." ), gr.Number( label="Minimum Speech Duration (ms)", precision=0, value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default), info="Final speech chunks shorter than this time are thrown out" ), gr.Number( label="Maximum Speech Duration (s)", value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX), info="Maximum duration of speech chunks in \"seconds\"." ), gr.Number( label="Minimum Silence Duration (ms)", precision=0, value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default), info="In the end of each speech chunk wait for this time before separating it" ), gr.Number( label="Speech Padding (ms)", precision=0, value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default), info="Final speech chunks are padded by this time each side" ) ] class DiarizationParams(BaseParams): """Speaker diarization parameters""" is_diarize: bool = Field(default=False, description="Enable speaker diarization") device: str = Field(default="cuda", description="Device to run Diarization model.") hf_token: str = Field( default="", description="Hugging Face token for downloading diarization models" ) @classmethod def to_gradio_inputs(cls, defaults: Optional[Dict] = None, available_devices: Optional[List] = None, device: Optional[str] = None) -> List[gr.components.base.FormComponent]: return [ gr.Checkbox( label=_("Enable Diarization"), value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default), ), gr.Dropdown( label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, value=defaults.get("device", device), ), gr.Textbox( label=_("HuggingFace Token"), value=defaults.get("hf_token", cls.__fields__["hf_token"].default), info=_("This is only needed the first time you download the model") ), ] class BGMSeparationParams(BaseParams): """Background music separation parameters""" is_separate_bgm: bool = Field(default=False, description="Enable background music separation") model_size: str = Field( default="UVR-MDX-NET-Inst_HQ_4", description="UVR model size" ) device: str = Field(default="cuda", description="Device to run UVR model.") segment_size: int = Field( default=256, gt=0, description="Segment size for UVR model" ) save_file: bool = Field( default=False, description="Whether to save separated audio files" ) enable_offload: bool = Field( default=True, description="Offload UVR model after transcription" ) @classmethod def to_gradio_input(cls, defaults: Optional[Dict] = None, available_devices: Optional[List] = None, device: Optional[str] = None, available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]: return [ gr.Checkbox( label=_("Enable Background Music Remover Filter"), value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default), interactive=True, info=_("Enabling this will remove background music") ), gr.Dropdown( label=_("Model"), choices=["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] if available_models is None else available_models, value=defaults.get("model_size", cls.__fields__["model_size"].default), ), gr.Dropdown( label=_("Device"), choices=["cpu", "cuda"] if available_devices is None else available_devices, value=defaults.get("device", device), ), gr.Number( label="Segment Size", value=defaults.get("segment_size", cls.__fields__["segment_size"].default), precision=0, info="Segment size for UVR model" ), gr.Checkbox( label=_("Save separated files to output"), value=defaults.get("save_file", cls.__fields__["save_file"].default), ), gr.Checkbox( label=_("Offload sub model after removing background music"), value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default), ) ] class WhisperParams(BaseParams): """Whisper parameters""" model_size: str = Field(default="large-v2", description="Whisper model size") lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe") is_translate: bool = Field(default=False, description="Translate speech to English end-to-end") beam_size: int = Field(default=5, ge=1, description="Beam size for decoding") log_prob_threshold: float = Field( default=-1.0, description="Threshold for average log probability of sampled tokens" ) no_speech_threshold: float = Field( default=0.6, ge=0.0, le=1.0, description="Threshold for detecting silence" ) compute_type: str = Field(default="float16", description="Computation type for transcription") best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling") patience: float = Field(default=1.0, gt=0, description="Beam search patience factor") condition_on_previous_text: bool = Field( default=True, description="Use previous output as prompt for next window" ) prompt_reset_on_temperature: float = Field( default=0.5, ge=0.0, le=1.0, description="Temperature threshold for resetting prompt" ) initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window") temperature: float = Field( default=0.0, ge=0.0, description="Temperature for sampling" ) compression_ratio_threshold: float = Field( default=2.4, gt=0, description="Threshold for gzip compression ratio" ) length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty") repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens") no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition") prefix: Optional[str] = Field(default=None, description="Prefix text for first window") suppress_blank: bool = Field( default=True, description="Suppress blank outputs at start of sampling" ) suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress") max_initial_timestamp: float = Field( default=0.0, ge=0.0, description="Maximum initial timestamp" ) word_timestamps: bool = Field(default=False, description="Extract word-level timestamps") prepend_punctuations: Optional[str] = Field( default="\"'“¿([{-", description="Punctuations to merge with next word" ) append_punctuations: Optional[str] = Field( default="\"'.。,,!!??::”)]}、", description="Punctuations to merge with previous word" ) max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk") chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds") hallucination_silence_threshold: Optional[float] = Field( default=None, description="Threshold for skipping silent periods in hallucination detection" ) hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model") language_detection_threshold: Optional[float] = Field( default=None, description="Threshold for language detection probability" ) language_detection_segments: int = Field( default=1, gt=0, description="Number of segments for language detection" ) batch_size: int = Field(default=24, gt=0, description="Batch size for processing") @field_validator('lang') def validate_lang(cls, v): from modules.utils.constants import AUTOMATIC_DETECTION return None if v == AUTOMATIC_DETECTION.unwrap() else v @field_validator('suppress_tokens') def validate_supress_tokens(cls, v): import ast try: if isinstance(v, str): suppress_tokens = ast.literal_eval(v) if not isinstance(suppress_tokens, list): raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") return suppress_tokens if isinstance(v, list): return v except Exception as e: raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}") @classmethod def to_gradio_inputs(cls, defaults: Optional[Dict] = None, only_advanced: Optional[bool] = True, whisper_type: Optional[str] = None, available_models: Optional[List] = None, available_langs: Optional[List] = None, available_compute_types: Optional[List] = None, compute_type: Optional[str] = None): whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower() inputs = [] if not only_advanced: inputs += [ gr.Dropdown( label=_("Model"), choices=available_models, value=defaults.get("model_size", cls.__fields__["model_size"].default), ), gr.Dropdown( label=_("Language"), choices=available_langs, value=defaults.get("lang", AUTOMATIC_DETECTION), ), gr.Checkbox( label=_("Translate to English?"), value=defaults.get("is_translate", cls.__fields__["is_translate"].default), ), ] inputs += [ gr.Number( label="Beam Size", value=defaults.get("beam_size", cls.__fields__["beam_size"].default), precision=0, info="Beam size for decoding" ), gr.Number( label="Log Probability Threshold", value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default), info="Threshold for average log probability of sampled tokens" ), gr.Number( label="No Speech Threshold", value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default), info="Threshold for detecting silence" ), gr.Dropdown( label="Compute Type", choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types, value=defaults.get("compute_type", compute_type), info="Computation type for transcription" ), gr.Number( label="Best Of", value=defaults.get("best_of", cls.__fields__["best_of"].default), precision=0, info="Number of candidates when sampling" ), gr.Number( label="Patience", value=defaults.get("patience", cls.__fields__["patience"].default), info="Beam search patience factor" ), gr.Checkbox( label="Condition On Previous Text", value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default), info="Use previous output as prompt for next window" ), gr.Slider( label="Prompt Reset On Temperature", value=defaults.get("prompt_reset_on_temperature", cls.__fields__["prompt_reset_on_temperature"].default), minimum=0, maximum=1, step=0.01, info="Temperature threshold for resetting prompt" ), gr.Textbox( label="Initial Prompt", value=defaults.get("initial_prompt", GRADIO_NONE_STR), info="Initial prompt for first window" ), gr.Slider( label="Temperature", value=defaults.get("temperature", cls.__fields__["temperature"].default), minimum=0.0, step=0.01, maximum=1.0, info="Temperature for sampling" ), gr.Number( label="Compression Ratio Threshold", value=defaults.get("compression_ratio_threshold", cls.__fields__["compression_ratio_threshold"].default), info="Threshold for gzip compression ratio" ) ] faster_whisper_inputs = [ gr.Number( label="Length Penalty", value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default), info="Exponential length penalty", ), gr.Number( label="Repetition Penalty", value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default), info="Penalty for repeated tokens" ), gr.Number( label="No Repeat N-gram Size", value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default), precision=0, info="Size of n-grams to prevent repetition" ), gr.Textbox( label="Prefix", value=defaults.get("prefix", GRADIO_NONE_STR), info="Prefix text for first window" ), gr.Checkbox( label="Suppress Blank", value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default), info="Suppress blank outputs at start of sampling" ), gr.Textbox( label="Suppress Tokens", value=defaults.get("suppress_tokens", "[-1]"), info="Token IDs to suppress" ), gr.Number( label="Max Initial Timestamp", value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default), info="Maximum initial timestamp" ), gr.Checkbox( label="Word Timestamps", value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default), info="Extract word-level timestamps" ), gr.Textbox( label="Prepend Punctuations", value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default), info="Punctuations to merge with next word" ), gr.Textbox( label="Append Punctuations", value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default), info="Punctuations to merge with previous word" ), gr.Number( label="Max New Tokens", value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN), precision=0, info="Maximum number of new tokens per chunk" ), gr.Number( label="Chunk Length (s)", value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default), precision=0, info="Length of audio segments in seconds" ), gr.Number( label="Hallucination Silence Threshold (sec)", value=defaults.get("hallucination_silence_threshold", GRADIO_NONE_NUMBER_MIN), info="Threshold for skipping silent periods in hallucination detection" ), gr.Textbox( label="Hotwords", value=defaults.get("hotwords", cls.__fields__["hotwords"].default), info="Hotwords/hint phrases for the model" ), gr.Number( label="Language Detection Threshold", value=defaults.get("language_detection_threshold", GRADIO_NONE_NUMBER_MIN), info="Threshold for language detection probability" ), gr.Number( label="Language Detection Segments", value=defaults.get("language_detection_segments", cls.__fields__["language_detection_segments"].default), precision=0, info="Number of segments for language detection" ) ] insanely_fast_whisper_inputs = [ gr.Number( label="Batch Size", value=defaults.get("batch_size", cls.__fields__["batch_size"].default), precision=0, info="Batch size for processing" ) ] if whisper_type != WhisperImpl.FASTER_WHISPER.value: for input_component in faster_whisper_inputs: input_component.visible = False if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value: for input_component in insanely_fast_whisper_inputs: input_component.visible = False inputs += faster_whisper_inputs + insanely_fast_whisper_inputs return inputs class TranscriptionPipelineParams(BaseModel): """Transcription pipeline parameters""" whisper: WhisperParams = Field(default_factory=WhisperParams) vad: VadParams = Field(default_factory=VadParams) diarization: DiarizationParams = Field(default_factory=DiarizationParams) bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams) def to_dict(self) -> Dict: data = { "whisper": self.whisper.to_dict(), "vad": self.vad.to_dict(), "diarization": self.diarization.to_dict(), "bgm_separation": self.bgm_separation.to_dict() } return data def to_list(self) -> List: """ Convert data class to the list because I have to pass the parameters as a list in the gradio. Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 See more about Gradio pre-processing: https://www.gradio.app/docs/components """ whisper_list = self.whisper.to_list() vad_list = self.vad.to_list() diarization_list = self.diarization.to_list() bgm_sep_list = self.bgm_separation.to_list() return whisper_list + vad_list + diarization_list + bgm_sep_list @staticmethod def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams': """Convert list to the data class again to use it in a function.""" data_list = deepcopy(pipeline_list) whisper_list = data_list[0:len(WhisperParams.__annotations__)] data_list = data_list[len(WhisperParams.__annotations__):] vad_list = data_list[0:len(VadParams.__annotations__)] data_list = data_list[len(VadParams.__annotations__):] diarization_list = data_list[0:len(DiarizationParams.__annotations__)] data_list = data_list[len(DiarizationParams.__annotations__):] bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)] return TranscriptionPipelineParams( whisper=WhisperParams.from_list(whisper_list), vad=VadParams.from_list(vad_list), diarization=DiarizationParams.from_list(diarization_list), bgm_separation=BGMSeparationParams.from_list(bgm_sep_list) )