Chris4K committed on
Commit
2bbcd9b
·
verified ·
1 Parent(s): 9aae24c

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +158 -0
utils.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+ """Some utility functions for the app."""
3
+ from base64 import b64encode
4
+ from io import BytesIO
5
+
6
+ from gtts import gTTS
7
+ from mtranslate import translate
8
+ from speech_recognition import AudioFile, Recognizer
9
+ from transformers import (BlenderbotSmallForConditionalGeneration,
10
+ BlenderbotSmallTokenizer)
11
+
12
+
13
def stt(audio: object, language: str) -> str:
    """Transcribe recorded speech into text.

    Args:
        audio: recording of the user's speech, opened through ``AudioFile``
        language (str): language forwarded to the Google recognizer

    Returns:
        str: text recognized from the recording
    """
    recognizer = Recognizer()
    with AudioFile(audio) as audio_source:
        # pull the complete recording into memory first
        recording = recognizer.record(audio_source)
        # then run speech-to-text on it in the requested language
        return recognizer.recognize_google(recording, language=language)
30
+
31
+
32
def to_en_translation(text: str, language: str) -> str:
    """Translate text written in the given language into English.

    Args:
        text (str): text to translate
        language (str): language the input text is written in

    Returns:
        str: the English translation
    """
    english_text = translate(text, to_language="en", from_language=language)
    return english_text
43
+
44
+
45
def from_en_translation(text: str, language: str) -> str:
    """Translate English text into the given language.

    Args:
        text (str): English text to translate
        language (str): language to translate the text into

    Returns:
        str: the translation in the requested language
    """
    translated_text = translate(text, to_language=language, from_language="en")
    return translated_text
56
+
57
+
58
class TextGenerationPipeline:
    """Pipeline for text generation with the BlenderbotSmall model.

    The tokenizer and the model are loaded once, when the class body is
    executed, and are shared by all instances as class attributes.
    Instances are callable: ``pipeline(text)`` returns the generated text.
    """

    # load tokenizer and the model
    model_name = "facebook/blenderbot_small-90M"
    tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_name)
    model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_name)

    def __init__(self, **kwargs):
        """Specifies the text generation parameters.

        For example: max_length=100 which generates text shorter than
        100 tokens. Visit:
        https://huggingface.co/docs/transformers/main_classes/text_generation
        for more parameters

        Args:
            **kwargs: keyword arguments forwarded to ``model.generate``
                on every call
        """
        # The parameters live directly in the instance __dict__ because
        # __call__ forwards the whole __dict__ to generate(). Any other
        # attribute set on an instance would be forwarded as well.
        self.__dict__.update(kwargs)

    def preprocess(self, text: str):
        """Tokenizes input text.

        Args:
            text (str): user specified text

        Returns:
            The tokenizer output for ``text`` requested with
            ``return_tensors="pt"``. This is not a ``str``: it is unpacked
            with ``**`` as keyword arguments for ``model.generate``.
        """
        return self.tokenizer(text, return_tensors="pt")

    def postprocess(self, outputs) -> str:
        """Converts generated output into text.

        Args:
            outputs: output of ``model.generate``; only the first
                sequence (``outputs[0]``) is decoded

        Returns:
            str: generated text with special tokens removed
        """
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def __call__(self, text: str) -> str:
        """Generates text from input text.

        Args:
            text (str): user specified text

        Returns:
            str: generated text
        """
        tokenized_text = self.preprocess(text)
        # instance __dict__ holds the generation parameters given to __init__
        output = self.model.generate(**tokenized_text, **self.__dict__)
        return self.postprocess(output)
114
+
115
+
116
def tts(text: str, language: str) -> object:
    """Build a text-to-speech object for the given text.

    Args:
        text (str): answer generated by the bot
        language (str): language passed to gTTS as ``lang``

    Returns:
        object: gTTS object created with ``slow=False``
    """
    speech = gTTS(text=text, lang=language, slow=False)
    return speech
126
+
127
+
128
def tts_to_bytesio(tts_object: object) -> bytes:
    """Render a tts object into raw audio bytes.

    Args:
        tts_object (object): audio object obtained from gtts; it must
            provide ``write_to_fp``

    Returns:
        bytes: everything the tts object wrote, as a bytes value
    """
    buffer = BytesIO()
    tts_object.write_to_fp(buffer)
    # rewind so the read below starts from the first byte written
    buffer.seek(0)
    return buffer.read()
141
+
142
+
143
def html_audio_autoplay(bytes: bytes, mime_type: str = "audio/wav") -> str:
    """Creates html markup for autoplaying audio at gradio app.

    The audio is embedded in the markup as a base64 ``data:`` URI, so no
    separate file has to be served.

    Args:
        bytes (bytes): audio bytes. The name shadows the builtin inside
            this function; it is kept so keyword callers keep working.
        mime_type (str): MIME type declared for the embedded audio.
            Defaults to "audio/wav", the value that used to be
            hard-coded. NOTE(review): gTTS output is MP3 according to its
            documentation, so "audio/mpeg" is presumably the accurate
            value for bytes produced from a gTTS object - confirm.

    Returns:
        str: html markup with an ``<audio controls autoplay>`` element
    """
    b64 = b64encode(bytes).decode()
    html = f"""
    <audio controls autoplay>
    <source src="data:{mime_type};base64,{b64}" type="{mime_type}">
    </audio>
    """
    return html