Hyperparameter?
Hi, I'm fine-tuning Whisper on Korean datasets to build a good Korean ASR model. I have about 30k hours of data.
However, my fine-tuned large-v3 and large-v3-turbo models currently perform very poorly. I think the dataset is okay, since fine-tuned whisper-base performs very well: CER goes down from 26.44 to 11.54 with this dataset.
Can you share your hyperparameters? Optimizer settings, learning rate, learning rate scheduler, batch size, total training steps, etc.
One more question: is this model fine-tuned from large-v3, not large-v1?
Hi! In our experience, and in the experience of the National Library of Norway, the larger the model you are training, the more important regularization becomes.
We follow many of the hyperparameters and general data preparation tips from the paper "Whispering in Norwegian: Navigating Orthographic and Dialectic Challenges" (Kummervold et al.). For regularization we use:
- BPE dropout of 0.2 (introduces randomness when merging subwords, so the tokenizer splits the same label text into different subwords from one pass to the next). BPE dropout is only used during training. For validation, and for evaluation once the model is finished, we go back to the regular tokenizer without dropout.
- Activation dropout of 0.1.
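The idea behind BPE dropout can be shown with a small pure-Python sketch: during encoding, each applicable merge is skipped with probability `dropout`, so the same word can be split into different subwords each time it is tokenized, while `dropout=0.0` gives ordinary deterministic BPE (what you want for validation and evaluation). The function `bpe_encode` and the toy merge list are hypothetical illustrations, not the Hugging Face `tokenizers` implementation.

```python
import random
from typing import Dict, List, Optional, Tuple


def bpe_encode(
    word: str,
    merges: List[Tuple[str, str]],
    dropout: float = 0.0,
    rng: Optional[random.Random] = None,
) -> List[str]:
    """Split `word` into subwords using ranked BPE merges, with optional BPE dropout."""
    rng = rng or random.Random()
    ranks: Dict[Tuple[str, str], int] = {pair: i for i, pair in enumerate(merges)}
    symbols = list(word)  # start from single characters
    while len(symbols) > 1:
        # Every adjacent pair with a merge rule is a candidate; with dropout,
        # each candidate is independently skipped with probability `dropout`.
        candidates = [
            (ranks[pair], i)
            for i, pair in enumerate(zip(symbols, symbols[1:]))
            if pair in ranks and rng.random() >= dropout
        ]
        if not candidates:
            break  # no surviving merges left
        _, i = min(candidates)  # apply the highest-priority (lowest-rank) merge
        symbols[i : i + 2] = [symbols[i] + symbols[i + 1]]
    return symbols


if __name__ == "__main__":
    merges = [("l", "o"), ("lo", "w"), ("e", "r")]
    print(bpe_encode("lower", merges))               # ['low', 'er']
    print(bpe_encode("lower", merges, dropout=1.0))  # ['l', 'o', 'w', 'e', 'r']
    rng = random.Random(0)
    # With dropout=0.2 the segmentation varies between calls, but always spells "lower".
    print([bpe_encode("lower", merges, dropout=0.2, rng=rng) for _ in range(3)])
```

Because the label text is re-tokenized on every pass, the model sees several valid segmentations of the same transcript, which acts as a regularizer on the decoder.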
You need a workaround in Hugging Face to load a BPE tokenizer with dropout. Here's how we save and reload the tokenizer with this feature:

```python
import os

from tokenizers.models import BPE
from transformers.trainer_utils import is_main_process

# 3. Regularization settings.
# a) BPE dropout in the tokenizer (randomly uses different subwords to encode the same word)
if custom_args.bpe_dropout > 0:
    # Need a workaround to successfully load the tokenizer with BPE dropout.
    # See https://github.com/huggingface/tokenizers/issues/201#issue-584182286
    # Should only be used for training, not for inference/eval.
    logger.info(f"cache_dir_tokenizer: {custom_args.cache_dir_tokenizer}")
    with training_args.main_process_first():
        if is_main_process(training_args.local_rank):
            # Only the main process writes vocab.json and merges.txt to disk.
            tokenizer._tokenizer.model.save(custom_args.cache_dir_tokenizer, "training_tokenizer")
        workaround_tokenizer = os.path.join(custom_args.cache_dir_tokenizer, "training_tokenizer-vocab.json")
        workaround_merges = os.path.join(custom_args.cache_dir_tokenizer, "training_tokenizer-merges.txt")
        workaround_files = [workaround_tokenizer, workaround_merges]
        # Every process reloads the BPE model from the saved files, this time with dropout enabled.
        tokenizer._tokenizer.model = BPE.from_file(*workaround_files, dropout=custom_args.bpe_dropout)
```
I think the regularization, together with extensive quality filtering of the data, is the most important part. In terms of data preparation:
- From subtitle sources, and from force-aligned data sources where text is aligned at the sentence level, we create "candidate chunks" of up to 30 s. Non-speech parts are also included in these chunks to the extent that they fit.
- We also explicitly create candidate chunks that are entirely non-speech, and train Whisper to predict `<|nospeech|>` for them.
- All candidate chunks are language-detected and transcribed with several existing Swedish ASR models. For every single candidate chunk, we calculate BLEU, weighted ROUGE-N, and the CER of the first 10 and last 10 characters.
- We apply light filtering in stage 1 training, but stricter filtering in stage 2 to steer the model toward a specific transcription style and to reduce hallucinations.
- We alternate training between no timestamps, with timestamps, with previous context as a prompt, and without a prompt. When training with timestamps, we ensure that the CER of the first 10 and last 10 characters between the machine transcription and the ground truth is `<= 0.2`-ish, so the model gets a consistent signal for when to output timestamps.
- We deliberately create extra, shorter candidate chunks and augment our dataset with them, so the model learns to handle audio of variable length and learns to output `<|endoftext|>` properly for shorter as well as longer chunks.
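A simplified sketch of the candidate-chunk step, assuming sentence-level segments with start and end times in seconds: consecutive segments are packed greedily into windows of at most 30 s (pauses between segments stay inside the chunk), and gaps between segments that are long enough become `<|nospeech|>` chunks. The `Segment` type, the function names, and the `min_gap` threshold are hypothetical; the pipeline described above also pads chunks with surrounding non-speech and adds extra shorter chunks, which this sketch leaves out.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Segment:
    start: float  # seconds
    end: float  # seconds
    text: str


def _merge(group: List[Segment]) -> Segment:
    # The chunk spans from the first segment's start to the last segment's end,
    # so any silence between the segments is part of the chunk's audio.
    return Segment(group[0].start, group[-1].end, " ".join(s.text for s in group))


def make_candidate_chunks(segments: List[Segment], max_duration: float = 30.0) -> List[Segment]:
    """Greedily pack time-ordered segments into chunks of at most `max_duration` seconds."""
    chunks: List[Segment] = []
    current: List[Segment] = []
    for seg in segments:
        if seg.end - seg.start > max_duration:
            # A single segment longer than the window cannot be used and breaks contiguity.
            if current:
                chunks.append(_merge(current))
                current = []
            continue
        if current and seg.end - current[0].start > max_duration:
            chunks.append(_merge(current))
            current = []
        current.append(seg)
    if current:
        chunks.append(_merge(current))
    return chunks


def make_nospeech_chunks(
    segments: List[Segment], min_gap: float = 5.0, max_duration: float = 30.0
) -> List[Segment]:
    """Turn gaps of at least `min_gap` seconds between segments into <|nospeech|> chunks."""
    chunks: List[Segment] = []
    for prev, nxt in zip(segments, segments[1:]):
        start = prev.end
        while nxt.start - start >= min_gap:
            end = min(nxt.start, start + max_duration)
            chunks.append(Segment(start, end, "<|nospeech|>"))
            start = end
    return chunks


if __name__ == "__main__":
    segs = [Segment(0, 10, "a"), Segment(12, 25, "b"), Segment(26, 40, "c"), Segment(41, 50, "d")]
    print([(c.start, c.end, c.text) for c in make_candidate_chunks(segs)])
    # [(0, 25, 'a b'), (26, 50, 'c d')]
```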
See this dataset for the kind of quality metrics we use and how we generally organize our data before training: https://huggingface.co/datasets/KBLab/rixvox-v2
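The CER of the first and last 10 characters, and the `<= 0.2` rule for timestamp training, could be computed along these lines. This is a sketch with hypothetical function names; it uses plain character-level Levenshtein distance divided by the reference length and does not show any text normalization that may be applied before comparison.

```python
from typing import Tuple


def levenshtein(a: str, b: str) -> int:
    """Character-level edit distance (insertions, deletions, substitutions)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution
        prev = curr
    return prev[-1]


def cer(reference: str, hypothesis: str) -> float:
    """Character error rate of `hypothesis` against `reference`."""
    if not reference:
        return 0.0 if not hypothesis else 1.0
    return levenshtein(reference, hypothesis) / len(reference)


def edge_cer(reference: str, hypothesis: str, n: int = 10) -> Tuple[float, float]:
    """CER of the first n and the last n characters (ground truth vs. machine transcript)."""
    return cer(reference[:n], hypothesis[:n]), cer(reference[-n:], hypothesis[-n:])


def suitable_for_timestamps(reference: str, hypothesis: str, threshold: float = 0.2) -> bool:
    """True if both chunk edges agree well enough to trust timestamp labels."""
    first, last = edge_cer(reference, hypothesis)
    return first <= threshold and last <= threshold


if __name__ == "__main__":
    print(edge_cer("abcdefghijklmnopqrst", "abcdefghijklmnopqrsX"))  # (0.0, 0.1)
    print(suitable_for_timestamps("abcdefghijklmnopqrst", "abcdefghijklmnopqrsX"))  # True
```

Checking only the edges targets exactly what matters for timestamps: whether the chunk boundaries in the ground truth line up with what is actually audible at the start and end of the audio.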
Our training code is somewhat customized for our own setup, since we separated most of the data preparation and training logic from the HF fine-tuning scripts. We will upload it nonetheless, and I'll link it. I think the main addition we make to HF's existing fine-tuning script is an improved DataCollator that handles timestamps, on-the-fly tokenization with BPE dropout, and previous context as a prompt.
Here's our DataCollator:
```python
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch


@dataclass
class DataCollatorSpeechSeq2SeqWithBPEDropoutPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        feature_extractor (`Any`)
            Mel spectrogram feature extractor for the audio.
        tokenizer (`Any`)
            The BPE dropout tokenizer used to tokenize text.
        decoder_start_token_id (`int`)
            The start token ID of the decoder.
        timestamp_probability (`float`)
            Probability of using timestamp labels.
        prompt_probability (`float`)
            Probability of using previous text as prompt.
        seed (`int`)
            Seed for reproducibility.
        truncation (`bool`)
            Whether to truncate the input text to max_length.
        max_length (`int`)
            The maximum length of the tokenized input text + special tokens.
        max_previous_text_length (`int`)
            The maximum length of the tokenized previous text to use as prompt.
        generation_max_length (`int`)
            Maximum generation length (not used by the collator itself).
        model_name (`str`)
            Model name or path. "large" models use <|nospeech|>, others <|nocaptions|>.
    """

    feature_extractor: Any
    tokenizer: Any
    decoder_start_token_id: int
    timestamp_probability: float = 0.8
    prompt_probability: float = 0.7
    seed: Optional[int] = None  # seed for reproducibility
    truncation: bool = True
    max_length: int = 448
    max_previous_text_length: int = 192
    generation_max_length: int = 448
    model_name: Optional[str] = None

    def __post_init__(self):
        # Seed a dedicated RNG once. Calling random.seed(self.seed) inside __call__
        # would replay the same random draws for every batch.
        self._rng = random.Random(self.seed)
        self._is_large = self.model_name is not None and "large" in self.model_name

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        model_input_name = "input_features"
        input_features = [{model_input_name: feature[model_input_name]} for feature in features]
        texts, texts_with_timestamps, previous_texts = self._get_texts_from_features(features)

        # Replace "<|nospeech|>" with "<|nocaptions|>" to avoid tokenization errors if not using large model
        if not self._is_large:
            texts = ["<|nocaptions|>" if text == "<|nospeech|>" else text for text in texts]
            texts_with_timestamps = ["<|nocaptions|>" if text == "<|nospeech|>" else text for text in texts_with_timestamps]
            previous_texts = ["<|nocaptions|>" if text == "<|nospeech|>" else text for text in previous_texts]
        nospeech_token_id = self.tokenizer.convert_tokens_to_ids(
            "<|nospeech|>" if self._is_large else "<|nocaptions|>"
        )

        # BPE dropout on the fly introduces randomness, and will lead to longer tokenized sequences
        # than in our pre-processing step. We need to handle when `len(input_ids) > max_length`
        # with truncation and max_length.
        self.tokenizer.set_prefix_tokens(predict_timestamps=False)
        tokenized_texts = self.tokenizer(
            texts, truncation=self.truncation, max_length=self.max_length
        )
        self.tokenizer.set_prefix_tokens(predict_timestamps=True)
        tokenized_text_timestamps = self.tokenizer(
            texts_with_timestamps, truncation=self.truncation, max_length=self.max_length
        )
        self.tokenizer.set_prefix_tokens(predict_timestamps=False)
        tokenized_previous_texts = self._tokenize_prompt(
            previous_texts,
            self.tokenizer,
            truncation=True,
            max_length=self.max_previous_text_length,
        )

        combined_input_ids = []
        for i, feature in enumerate(features):
            # Use timestamp labels with probability `timestamp_probability` whenever observation is suitable
            if (self._rng.random() < self.timestamp_probability) and feature["stage2_whisper_timestamps"]:
                input_ids_to_concat = tokenized_text_timestamps["input_ids"][i]
            else:
                input_ids_to_concat = tokenized_texts["input_ids"][i]

            prompt_ids = tokenized_previous_texts[i]
            combined_length = len(prompt_ids) + len(input_ids_to_concat)
            # prompt_ids[0] is always <|startofprev|>, so an empty prompt has length 1
            previous_text_is_empty = len(prompt_ids) <= 1
            train_with_prompt = self._rng.random() < self.prompt_probability
            # Check the first token *after* <|startofprev|> for the no-speech token
            previous_text_is_nospeech = len(prompt_ids) > 1 and prompt_ids[1] == nospeech_token_id

            if (
                (combined_length <= self.max_length)
                and train_with_prompt
                and not previous_text_is_empty
                and not previous_text_is_nospeech
            ):
                combined_tokens = prompt_ids + input_ids_to_concat
            else:
                combined_tokens = list(input_ids_to_concat)
            combined_input_ids.append(combined_tokens)

        # Pad the labels to the same length
        labels = self._pad_labels(combined_input_ids, self.tokenizer.pad_token_id)
        batch = self.feature_extractor.pad(input_features, return_tensors="pt")

        # We manually "shift" the labels to the right and ensure decoder_input_ids are same length:
        # original labels: <SoT> 1 2 3 4 <EoT>
        # decoder input:   <SoT> 1 2 3 4
        # shifted labels:  1 2 3 4 <EoT>
        decoder_input_ids = labels[:, :-1].clone()  # truncate last token
        labels = labels[:, 1:]  # truncate first token

        # Everything after the first <|endoftext|> is padding (Whisper pads with <|endoftext|>)
        endoftext_token_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
        first_endoftext = (labels == endoftext_token_id).int().argmax(dim=1)
        padding_mask = torch.arange(labels.shape[1]).unsqueeze(0) > first_endoftext.unsqueeze(1)
        # Replace padding tokens with -100 to ignore them in loss calculation
        labels.masked_fill_(padding_mask, -100)

        # Replace initial prompt tokens with -100 to ignore them in loss calculation
        sot_index = torch.argmax(
            (labels == self.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")).int(), dim=1
        )
        prompt_mask = torch.arange(labels.shape[1]).unsqueeze(0) < sot_index.unsqueeze(1)
        labels.masked_fill_(prompt_mask, -100)

        # If we manually pass decoder_input_ids, HF won't shift the labels to the right, and instead uses the decoder_input_ids we provide
        # https://github.com/huggingface/transformers/blob/504c4d36929b6bb8a8c2ecfad0f2625f4075f22a/src/transformers/models/whisper/modeling_whisper.py#L1761-L1764
        batch["decoder_input_ids"] = decoder_input_ids
        batch["labels"] = labels
        # No decoder attention mask is passed: padding tokens are only present on the right side,
        # so the causal mask already prevents real tokens from attending to them.
        return batch

    def _pad_labels(self, labels: List[List[int]], pad_token_id: int) -> torch.Tensor:
        """
        Regular texts and texts with timestamps are processed independently and have different lengths
        when combined. This function pads the combined labels to the same length.
        """
        labels = [torch.tensor(label) for label in labels]
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=pad_token_id
        )
        return labels

    def _tokenize_prompt(self, texts, tokenizer, truncation=False, max_length=192):
        """
        Tokenize the prompt text and add <|startofprev|> token at the beginning.
        Left truncate the tokenized text to max_length if necessary.

        Args:
            texts: list of texts
            tokenizer: tokenizer
            truncation: whether to truncate the tokenized text to max_length of the model
            max_length: maximum prompt length, including <|startofprev|>
        """
        tokenizer.set_prefix_tokens(predict_timestamps=False)
        start_of_prev_token = tokenizer.convert_tokens_to_ids("<|startofprev|>")
        prompt_tokens = tokenizer(texts, truncation=truncation, add_special_tokens=False).input_ids
        # Left truncate to (max_length - 1) to fit in added <|startofprev|>
        prompt_tokens_with_start = []
        for prompt_token in prompt_tokens:
            truncated_token = prompt_token[-(max_length - 1) :]
            prompt_with_start_token = [start_of_prev_token]
            prompt_with_start_token.extend(truncated_token)
            prompt_tokens_with_start.append(prompt_with_start_token)
        return prompt_tokens_with_start

    def _get_texts_from_features(
        self, features: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str], List[str]]:
        """
        Extract different texts from features. I.e.
        - regular text
        - text with whisper timestamps
        - previous text (text to be used as prompt)

        Args:
            features: list of features
        Returns:
            texts: list of regular texts
            text_with_timestamps: list of texts with timestamps
            previous_texts: list of previous texts
        """
        texts = []
        text_with_timestamps = []
        previous_texts = []
        for feature in features:
            texts.append(feature["text"])
            text_with_timestamps.append(feature["text_timestamps"])
            previous_text = (
                feature["previous_text"] if feature["previous_text"] is not None else ""
            )
            previous_texts.append(previous_text)
        return texts, text_with_timestamps, previous_texts
```
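Stripped of the tokenizer and tensor handling, the per-example label logic in this collator boils down to roughly the following pure-Python sketch. Token IDs are plain integers, and the function name and toy IDs are hypothetical; the real collator works on whole batches, also skips prompts that are `<|nospeech|>`, pads with `<|endoftext|>`, and masks everything after the first `<|endoftext|>`.

```python
import random
from typing import List, Tuple

IGNORE_INDEX = -100  # labels with this value are ignored by the cross-entropy loss


def build_example(
    text_ids: List[int],
    timestamp_ids: List[int],
    prompt_ids: List[int],
    sot_id: int,
    rng: random.Random,
    has_timestamps: bool = True,
    timestamp_probability: float = 0.5,
    prompt_probability: float = 0.5,
    max_length: int = 448,
) -> Tuple[List[int], List[int]]:
    """Return (decoder_input_ids, labels) for one example."""
    # 1. Pick the target: timestamped labels only if the chunk is suitable for them.
    use_timestamps = rng.random() < timestamp_probability and has_timestamps
    target = timestamp_ids if use_timestamps else text_ids

    # 2. Optionally prepend the previous text (<|startofprev|> + tokens) as a prompt.
    use_prompt = rng.random() < prompt_probability
    if use_prompt and len(prompt_ids) > 1 and len(prompt_ids) + len(target) <= max_length:
        tokens = prompt_ids + target
    else:
        tokens = list(target)

    # 3. Shift: the decoder sees tokens[:-1] and must predict tokens[1:].
    decoder_input_ids = tokens[:-1]
    labels = tokens[1:]

    # 4. No loss on the prompt: mask every label before <|startoftranscript|>.
    if sot_id in labels:
        k = labels.index(sot_id)
        labels = [IGNORE_INDEX] * k + labels[k:]
    return decoder_input_ids, labels


if __name__ == "__main__":
    # Toy IDs: 1 = <|startofprev|>, 2 = <|startoftranscript|>, 3 = <|endoftext|>
    rng = random.Random(0)
    print(build_example([2, 20, 21, 3], [2, 50, 20, 21, 51, 3], [1, 10, 11], sot_id=2, rng=rng,
                        timestamp_probability=0.0, prompt_probability=1.0))
    # ([1, 10, 11, 2, 20, 21], [-100, -100, 2, 20, 21, 3])
```

Masking the prompt means the model learns to condition on previous context without being trained to reproduce it.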
And here are the settings we use when launching stage 1 training (learning rate, global batch size, etc.):
```bash
torchrun \
  --nproc_per_node=$NPROC_PER_NODE \
  --nnodes=$SLURM_JOB_NUM_NODES \
  --node_rank=$SLURM_NODEID \
  --master_addr=$MASTER_ADDR \
  --master_port=$MASTER_PORT \
  scripts/run_speech_recognition_seq2seq_streaming_bpe_previous.py \
  --deepspeed=$CONFIG_DIR"/ds_config_large.json" \
  --model_name_or_path="/leonardo_work/EUHPC_A01_006/models/whisper-large-v3" \
  --node_id=$SLURM_NODEID \
  --proc_id=$SLURM_PROCID \
  --dataset_name="/leonardo_scratch/large/userexternal/jsikora0/interleave_large/interleave_large_stage1_batch_size_128/" \
  --language="swedish" \
  --train_split_name="train" \
  --max_steps="150000" \
  --max_eval_samples="2048" \
  --dataloader_num_workers="5" \
  --cache_dir=/leonardo_scratch/large/userexternal/jsikora0/cache \
  --cache_dir_tokenizer=/leonardo_scratch/large/userexternal/jsikora0/cache_tokenizer \
  --output_dir=outputs/2024-12-26_large-stage1-workers5-fa2 \
  --per_device_train_batch_size="16" \
  --gradient_accumulation_steps="2" \
  --per_device_eval_batch_size="4" \
  --logging_steps="10" \
  --learning_rate="7e-5" \
  --warmup_steps="5000" \
  --eval_strategy="steps" \
  --eval_steps="750" \
  --save_steps="750" \
  --save_strategy="steps" \
  --max_length="448" \
  --generation_max_length="448" \
  --max_duration_in_seconds="30" \
  --text_column_name="input_features" \
  --audio_column_name="input_features" \
  --freeze_feature_encoder="False" \
  --report_to="tensorboard" \
  --metric_for_best_model="wer" \
  --greater_is_better="False" \
  --dispatch_batches="False" \
  --gradient_checkpointing=True \
  --ignore_data_skip=True \
  --adam_epsilon="1e-6" \
  --adam_beta1="0.9" \
  --adam_beta2="0.98" \
  --weight_decay="0.01" \
  --lr_scheduler_type="linear" \
  --bpe_dropout="0.2" \
  --fp16 \
  --do_train \
  --do_eval \
  --streaming \
  --shuffle_buffer_size="8" \
  --predict_with_generate \
  --stamps_probs="0.5" \
  --prompt_probability="0.5" \
  --remove_unused_columns="False" \
  --activation_dropout="0.1" \
  --seed="479"
```
In stage 2 we use a lower learning rate (`--learning_rate="3e-6"`). Our global batch size is 1024. We train on 8 nodes, each with 4x A100 64GB GPUs. Here's our DeepSpeed config:
```json
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "last_batch_iteration": -1,
      "total_num_steps": "auto",
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 1,
    "offload_optimizer": {
      "device": "none",
      "pin_memory": true
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}
```
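As a sanity check, the global batch size of 1024 follows from the launch settings: 8 nodes × 4 GPUs × `--per_device_train_batch_size=16` × `--gradient_accumulation_steps=2`. The small helpers below (names hypothetical) do that arithmetic and estimate how many training examples stage 1 consumes with `--max_steps=150000`, assuming, as in the HF Trainer, that a step is one optimizer update over one global batch.

```python
def global_batch_size(nodes: int, gpus_per_node: int, per_device_batch: int, grad_accum: int) -> int:
    """Examples consumed per optimizer update across all GPUs."""
    return nodes * gpus_per_node * per_device_batch * grad_accum


def samples_seen(max_steps: int, global_batch: int) -> int:
    """Each optimizer step consumes one full global batch."""
    return max_steps * global_batch


if __name__ == "__main__":
    gb = global_batch_size(nodes=8, gpus_per_node=4, per_device_batch=16, grad_accum=2)
    print(gb)                         # 1024
    print(samples_seen(150_000, gb))  # 153600000
```

If you train on fewer GPUs, raising `--gradient_accumulation_steps` is the usual way to keep the same global batch size.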
> One more question: is this model fine-tuned from large-v3, not large-v1?
Fine-tuned from large-v3.
Another important tip if you store your data as parquet and use the `datasets` library to load it in streaming mode: it is very important to set `row_group_size` to a smaller value when saving the parquet files (around 100 is OK). This lets streaming parquet readers load 100 rows at a time from a shard, instead of loading the entire parquet file into CPU memory. In Hugging Face `datasets` this argument is named `batch_size` in the `.to_parquet()` method.
Without it, your training can easily run out of regular RAM (CPU OOM) when streaming from parquet.
Huge thanks for the detailed reply! It really helps. I'm going to try 2-stage training with data filtering and regularization.