Mattral committed
Commit 732e00c · verified · 1 Parent(s): bd7b4bf

Update app.py

Files changed (1)
  1. app.py +48 -102
app.py CHANGED
@@ -1,118 +1,64 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
 import random
-import textwrap
-from transformers import pipeline
-import numpy as np

-# Load the Whisper model for automatic speech recognition
-transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
-
-# Define the model to be used
+# Initialize the model
 model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 client = InferenceClient(model)

-# Embedded system prompt
-system_prompt_text = (
-    "You are a smart and helpful co-worker of Thailand based multi-national company PTT, and PTTEP. "
-    "You help with any kind of request and provide a detailed answer to the question. But if you are asked about something "
-    "unethical or dangerous, you must refuse and provide a safe and respectful way to handle that."
-)
-
-# Function to transcribe audio input
-def transcribe(audio):
-    if audio is None:
-        return None  # Handle case where audio input is None
-
-    sr, y = audio
-    # Convert to mono if stereo
-    if y.ndim > 1:
-        y = y.mean(axis=1)
-
-    y = y.astype(np.float32)
-    y /= np.max(np.abs(y))  # Normalize audio
-
-    return transcriber({"sampling_rate": sr, "raw": y})["text"]  # Transcribe audio
-
-def format_prompt_mixtral(message, history):
-    prompt = "<s>"
-    prompt += f"{system_prompt_text}\n\n"  # Add the system prompt
-
-    if history:
-        for user_prompt, bot_response in history:
-            prompt += f"[INST] {user_prompt} [/INST]"
-            prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
-def chat_inf(prompt, history, seed, temp, tokens, top_p, rep_p):
-    generate_kwargs = dict(
-        temperature=temp,
-        max_new_tokens=tokens,
-        top_p=top_p,
-        repetition_penalty=rep_p,
-        do_sample=True,
-        seed=seed,
-    )
-
-    formatted_prompt = format_prompt_mixtral(prompt, history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+def chat_response(prompt, history, seed, temp, tokens, top_p, rep_p):
+    generate_kwargs = {
+        "temperature": temp,
+        "max_new_tokens": tokens,
+        "top_p": top_p,
+        "repetition_penalty": rep_p,
+        "do_sample": True,
+        "seed": seed,
+    }
+
+    # Include the chat history in the prompt
+    formatted_prompt = "\n".join([f"Q: {user_prompt}\nA: {bot_response}" for user_prompt, bot_response in history]) + f"\nQ: {prompt}\nA:"
+
     output = ""
-    for response in stream:
-        output += response.token.text
-        yield [(prompt, output)]
-    history.append((prompt, output))
-    yield history
+
+    # Generating text in streaming mode
+    for response in client.text_generation(formatted_prompt, **generate_kwargs, stream=True):
+        # Assuming response is directly a string or contains a message
+        output += response  # Using response directly since it's a string
+
+        # Yield the updated output for real-time display
+        yield [(prompt, output)]

-def clear_fn():
-    return None, None
-
-rand_val = random.randint(1, 1111111111111111)
-
-def check_rand(inp, val):
-    if inp:
-        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
-    else:
-        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
+    # Append the full response to history after completion
+    history.append((prompt, output))
+    yield history  # Yielding the updated history

+def clear_chat():
+    return [], []  # Returning an empty history

+# Gradio interface
 with gr.Blocks() as app:
-    gr.HTML("""<center><h1 style='font-size:xx-large;'>PTT Chatbot</h1><br><h3>running on Huggingface Inference</h3><br><h7>EXPERIMENTAL</center>""")
+    gr.HTML("<center><h1>Chatbot</h1><h3>Ask your questions!</h3></center>")

-    with gr.Row():
-        chat = gr.Chatbot(height=500)
-
-    with gr.Group():
-        with gr.Row():
-            with gr.Column(scale=3):
-                inp = gr.Audio(type="filepath")  # Remove the source parameter
-                with gr.Row():
-                    with gr.Column(scale=2):
-                        btn = gr.Button("Chat")
-                    with gr.Column(scale=1):
-                        with gr.Group():
-                            stop_btn = gr.Button("Stop")
-                            clear_btn = gr.Button("Clear")
-            with gr.Column(scale=1):
-                with gr.Group():
-                    rand = gr.Checkbox(label="Random Seed", value=True)
-                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
-                    tokens = gr.Slider(label="Max new tokens", value=3840, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of tokens")
-                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
-                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
-                    rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
-
-    def handle_chat(audio_input, chat_history, seed, temp, tokens, top_p, rep_p):
-        user_message = transcribe(audio_input)  # Transcribe audio to text
-        if user_message is None or user_message == "":  # Check for empty or error in recognition
-            return chat_history, "Sorry, I couldn't understand that."
-
-        response_gen = chat_inf(user_message, chat_history, seed, temp, tokens, top_p, rep_p)
-        response = next(response_gen)[0][-1][1]  # Get the response text
-        return chat_history + [(user_message, response)], response  # Return updated chat history
-
-    go = btn.click(handle_chat, [inp, chat, seed, temp, tokens, top_p, rep_p], chat)
-
-    stop_btn.click(None, None, None, cancels=[go])
-    clear_btn.click(clear_fn, None, [inp, chat])
-
-app.queue(default_concurrency_limit=10).launch(share=True, auth=("admin", "0112358"))  # Launch the app with authentication
+    chat_box = gr.Chatbot(height=500)
+    inp = gr.Textbox(label="Your Question", lines=5)
+    btn = gr.Button("Ask")
+    clear_btn = gr.Button("Clear")
+
+    rand_seed = gr.Checkbox(label="Random Seed", value=True)
+    seed_slider = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
+    tokens_slider = gr.Slider(label="Max new tokens", value=3840, minimum=0, maximum=8000)
+    temp_slider = gr.Slider(label="Temperature", value=0.9, minimum=0.01, maximum=1.0)
+    top_p_slider = gr.Slider(label="Top-P", value=0.9, minimum=0.01, maximum=1.0)
+    rep_p_slider = gr.Slider(label="Repetition Penalty", value=1.0, minimum=0.1, maximum=2.0)
+
+    # Handle button click to get chat response
+    btn.click(
+        lambda prompt: chat_response(prompt, [], seed_slider.value, temp_slider.value, tokens_slider.value, top_p_slider.value, rep_p_slider.value),
+        inp,
+        chat_box,
+    )
+
+    clear_btn.click(clear_chat, None, [inp, chat_box])
+
+app.launch(share=True, auth=("admin", "0112358"))
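
Note on the streaming change: the removed chat_inf called client.text_generation(..., details=True) and read each chunk's response.token.text, while the new chat_response streams with the default details=False, where huggingface_hub yields plain strings. A minimal sketch of both modes of InferenceClient.text_generation (the prompt text here is only illustrative):

from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# details=False (the new chat_response path): each streamed chunk is a plain str
for chunk in client.text_generation("Q: Hello\nA:", max_new_tokens=16, stream=True):
    print(chunk, end="")

# details=True (the removed chat_inf path): each chunk is a
# TextGenerationStreamResponse whose text lives on .token.text
for resp in client.text_generation("Q: Hello\nA:", max_new_tokens=16, stream=True, details=True):
    print(resp.token.text, end="")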
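
While streaming, chat_response yields [(prompt, output)], so the Chatbot display collapses to the current turn until the final yield history. A hedged sketch of one way to keep earlier turns visible mid-stream, assuming the same (user, bot) tuple format; chat_response_keep_history is a hypothetical name, not part of the commit:

def chat_response_keep_history(prompt, history, seed, temp, tokens, top_p, rep_p):
    # Identical sampling setup and Q/A prompt format to chat_response above
    generate_kwargs = {
        "temperature": temp,
        "max_new_tokens": tokens,
        "top_p": top_p,
        "repetition_penalty": rep_p,
        "do_sample": True,
        "seed": seed,
    }
    formatted_prompt = "\n".join(
        f"Q: {u}\nA: {a}" for u, a in history
    ) + f"\nQ: {prompt}\nA:"

    output = ""
    for chunk in client.text_generation(formatted_prompt, **generate_kwargs, stream=True):
        output += chunk
        # Prepend the prior turns so the display never shrinks to one exchange
        yield history + [(prompt, output)]

    history.append((prompt, output))
    yield history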
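
Two wiring details in the new Blocks section are worth flagging: the lambda passed to btn.click snapshots seed_slider.value (and the other sliders) once at build time and hard-codes history to [], so slider changes and prior turns never reach chat_response; clear_chat also returns a list for the Textbox, and the rand_seed checkbox is left unwired. A hedged sketch of the more idiomatic wiring, assuming the component names above, where Gradio reads each component's current value on every click:

# Pass components, not .value snapshots; the Chatbot supplies the live history
btn.click(
    chat_response,
    [inp, chat_box, seed_slider, temp_slider, tokens_slider, top_p_slider, rep_p_slider],
    chat_box,
)

# Clear with type-appropriate empties: "" for the Textbox, [] for the Chatbot
clear_btn.click(lambda: ("", []), None, [inp, chat_box])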