|
import argparse |
|
from transformers import AutoProcessor |
|
from transformers import Wav2Vec2ProcessorWithLM |
|
from pyctcdecode import build_ctcdecoder |
|
|
|
|
|
def main(args):
    """Attach a KenLM-backed CTC decoder to a pretrained processor and save it.

    Loads the processor from ``args.model_name_or_path``, builds a pyctcdecode
    decoder from the tokenizer vocabulary (lowercased, ordered by token id) and
    the language model at ``args.kenlm_model_path``, wraps feature extractor,
    tokenizer and decoder in a ``Wav2Vec2ProcessorWithLM`` and saves that back
    to ``args.model_name_or_path``. Finally prints a suggested follow-up shell
    command.

    Args:
        args: Object exposing ``model_name_or_path`` and ``kenlm_model_path``
            attributes (e.g. an ``argparse.Namespace``).

    Raises:
        ValueError: If lowercasing makes two distinct vocabulary tokens equal.
    """
    processor = AutoProcessor.from_pretrained(args.model_name_or_path)

    # The decoder receives labels positionally, so list the tokens in
    # ascending token-id order.
    vocab = processor.tokenizer.get_vocab()
    tokens_by_id = sorted(vocab, key=vocab.get)
    labels = [token.lower() for token in tokens_by_id]

    # Collapsing duplicates (as a dict keyed on the lowercased token would)
    # shortens the label list and shifts every later label relative to its
    # token id, so fail loudly instead of building a misaligned decoder.
    if len(set(labels)) != len(labels):
        seen = set()
        clashes = sorted({lbl for lbl in labels if lbl in seen or seen.add(lbl)})
        raise ValueError(
            "Lowercasing the tokenizer vocabulary produced duplicate labels: "
            f"{clashes!r}"
        )

    decoder = build_ctcdecoder(
        labels=labels,
        kenlm_model_path=args.kenlm_model_path,
    )

    processor_with_lm = Wav2Vec2ProcessorWithLM(
        feature_extractor=processor.feature_extractor,
        tokenizer=processor.tokenizer,
        decoder=decoder,
    )
    processor_with_lm.save_pretrained(args.model_name_or_path)

    # Plain string: there is nothing to interpolate, `$(pwd)` is shell syntax.
    print(
        "Run: ~/bin/build_binary language_model/*.arpa language_model/5gram.bin -T $(pwd) && rm language_model/*.arpa"
    )
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the arguments from ``sys.argv``,
            which is the behavior of the original zero-argument call.

    Returns:
        argparse.Namespace with ``model_name_or_path`` (defaults to ``"./"``)
        and ``kenlm_model_path`` (required).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        default="./",
        help='Model name or path. Defaults to ./',
    )
    parser.add_argument(
        '--kenlm_model_path',
        required=True,
        help='Path to KenLM arpa file.',
    )
    return parser.parse_args(argv)
|
|
|
|
|
# Script entry point: parse the command-line options, then run the conversion.
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
|
|