yixinsong committed
Commit 25976f2 · Parent(s): c878e28
Files changed (1):
  1. app.py  +47 -12
app.py CHANGED
@@ -11,7 +11,7 @@ import uuid
 import json
 
 # Constants
-SYSTEM_PROMPT = """You are SmallThinker-3B, a helpful AI assistant. You try to follow instructions as much as possible while being accurate and brief."""
+SYSTEM_PROMPT = """You are a helpful assistant."""
 device = "cuda" if torch.cuda.is_available() else "cpu"
 TITLE = "<h1><center>SmallThinker-3B Chat</center></h1>"
 MODEL_PATH = "PowerInfer/SmallThinker-3B-Preview"
@@ -53,7 +53,7 @@ button {
 # Load model and tokenizer
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_PATH,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=torch.float16,
 ).to(device)
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 
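
This hunk only switches the load dtype from bfloat16 to float16, the safer choice on GPUs without native bfloat16 support (e.g. pre-Ampere cards). If the deployment target varies, the dtype can also be picked at runtime; a minimal sketch, with a hypothetical pick_dtype helper that is not part of this commit:

import torch

def pick_dtype() -> torch.dtype:
    # Hypothetical helper: prefer bfloat16 where the GPU supports it natively,
    # fall back to float16 on older CUDA devices, and float32 on CPU.
    if not torch.cuda.is_available():
        return torch.float32
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

# e.g. AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=pick_dtype())
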
@@ -129,7 +129,13 @@ def stream_chat(
     # with logs_file.open("a") as f:
     #     f.write(json.dumps({"input": input_text.replace(SYSTEM_PROMPT, ""), "output": buffer.replace(SYSTEM_PROMPT, ""), "model": "SmallThinker-3B"}))
     #     f.write("\n")
-
+def stop_generation():
+    stop_event.set()
+    return {
+        stop_btn: gr.Button.update(interactive=False),
+        submit: gr.Button.update(interactive=True),
+        textbox: gr.Textbox.update(interactive=True)
+    }
 def clear_input():
     return ""
 
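
The new stop_generation() sets a stop_event that is not defined in any of the hunks shown here; presumably it is a shared threading.Event that stream_chat consults while streaming tokens. A minimal sketch of one way such a flag can interrupt generation, using transformers' StoppingCriteria (the wiring below is an assumption, not code from this commit):

import threading
from transformers import StoppingCriteria, StoppingCriteriaList

stop_event = threading.Event()  # assumed shared flag, cleared before each request

class StopOnEvent(StoppingCriteria):
    """Abort generation as soon as the shared event is set (e.g. by the Stop button)."""
    def __init__(self, event: threading.Event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return self.event.is_set()

# Inside stream_chat one could then pass:
#     stopping_criteria=StoppingCriteriaList([StopOnEvent(stop_event)])
# to model.generate(), and call stop_event.clear() at the start of each new request.
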
@@ -140,7 +146,12 @@ def add_message(message: str, history: list):
 
 def clear_session() -> Tuple[str, List]:
     return '', []
-
+def on_submit(textbox, chatbot, *args):
+    return {
+        textbox: gr.Textbox.update(value="", interactive=False),
+        submit: gr.Button.update(interactive=False),
+        stop_btn: gr.Button.update(interactive=True),
+    }
 def main():
     with gr.Blocks(css=CSS, theme="soft") as demo:
         gr.HTML(TITLE)
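
A note on the API this hunk relies on: dict-style returns like the one in on_submit must be keyed by the component instances that are also listed in the event's outputs, and gr.Textbox.update() / gr.Button.update() belong to the Gradio 3.x API; Gradio 4.x removed the per-component update() methods in favor of gr.update(). A hedged 4.x-style equivalent, assuming the same components and positional outputs, would be:

import gradio as gr

def on_submit(message, history, *args):
    # Gradio 4.x style: return gr.update(...) objects positionally, matched
    # against outputs=[textbox, submit, stop_btn] in the event listener.
    return (
        gr.update(value="", interactive=False),  # textbox: clear and lock input
        gr.update(interactive=False),            # submit: disable while generating
        gr.update(interactive=True),             # stop_btn: enable
    )
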
@@ -167,39 +178,63 @@ def main():
         with gr.Row():
             clear_history = gr.Button("🧹 Clear History")
             submit = gr.Button("🚀 Send")
+            stop_btn = gr.Button("🛑 Stop", interactive=False)
 
         # Chain of events for submit button
         submit_event = submit.click(
-            fn=add_message,
+            fn=on_submit,
             inputs=[textbox, chatbot],
-            outputs=chatbot,
+            outputs=[textbox, submit, stop_btn],
             queue=False
         ).then(
-            fn=clear_input,
-            outputs=textbox,
+            fn=add_message,
+            inputs=[textbox, chatbot],
+            outputs=chatbot,
             queue=False
         ).then(
             fn=stream_chat,
             inputs=[textbox, chatbot, temperature, max_new_tokens, top_p, top_k, repetition_penalty],
             outputs=chatbot,
             show_progress=True
+        ).then(
+            fn=lambda: {
+                textbox: gr.Textbox.update(interactive=True),
+                submit: gr.Button.update(interactive=True),
+                stop_btn: gr.Button.update(interactive=False)
+            },
+            outputs=[textbox, submit, stop_btn]
         )
 
         # Chain of events for enter key
         enter_event = textbox.submit(
-            fn=add_message,
+            fn=on_submit,
             inputs=[textbox, chatbot],
-            outputs=chatbot,
+            outputs=[textbox, submit, stop_btn],
             queue=False
         ).then(
-            fn=clear_input,
-            outputs=textbox,
+            fn=add_message,
+            inputs=[textbox, chatbot],
+            outputs=chatbot,
             queue=False
         ).then(
             fn=stream_chat,
             inputs=[textbox, chatbot, temperature, max_new_tokens, top_p, top_k, repetition_penalty],
             outputs=chatbot,
             show_progress=True
+        ).then(
+            fn=lambda: {
+                textbox: gr.Textbox.update(interactive=True),
+                submit: gr.Button.update(interactive=True),
+                stop_btn: gr.Button.update(interactive=False)
+            },
+            outputs=[textbox, submit, stop_btn]
+        )
+
+        # Stop button event
+        stop_btn.click(
+            fn=stop_generation,
+            outputs=[stop_btn, submit, textbox],
+            queue=False
         )
 
         clear_history.click(fn=clear_session,
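
As wired above, the Stop button only sets the flag and re-enables the inputs; the in-flight stream_chat generator keeps running until a stopping criterion notices the flag. Since the two chains are bound to names (submit_event, enter_event), Gradio can also abort them directly through the cancels argument of the click listener. A hedged sketch of that variant, which is an assumption and not necessarily what the rest of app.py does:

# Hypothetical alternative: let Gradio cancel the queued stream_chat events outright.
stop_btn.click(
    fn=stop_generation,
    outputs=[stop_btn, submit, textbox],
    cancels=[submit_event, enter_event],
)
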