kz919 committed
Commit 6d0db7f · verified · 1 Parent(s): 5fb8783

Update app.py

Files changed (1): app.py (+21, -10)
app.py CHANGED
@@ -1,6 +1,6 @@
 import spaces
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

 # Load the model and tokenizer locally
 model_name = "kz919/QwQ-0.5B-Distilled-SFT"
@@ -30,16 +30,27 @@ def respond(message, history: list[tuple[str, str]], system_message, max_tokens,
     # Tokenize the input prompt
     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

-    # Generate a response
-    outputs = model.generate(
-        inputs.input_ids,
-        max_length=max_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        pad_token_id=tokenizer.eos_token_id,
-        streamer = TextStreamer(tokenizer)
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Use a thread to run the generation in parallel
+    generation_thread = threading.Thread(
+        target=model.generate,
+        kwargs=dict(
+            inputs=inputs.input_ids,
+            max_length=max_tokens,
+            streamer=streamer,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+            pad_token_id=tokenizer.eos_token_id,
+        ),
     )
-    yield outputs
+    generation_thread.start()
+
+    # Stream the tokens as they are generated
+    for new_text in streamer:
+        yield new_text


 # Create the Gradio interface
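
For reference, a minimal self-contained sketch of the TextIteratorStreamer plus background-thread pattern introduced by this commit. It reuses the model name from the diff; the prompt, max_new_tokens, temperature, and top_p values are placeholders, the calls assume a CUDA device as in the Space, and the pattern relies on import threading being present in app.py.

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "kz919/QwQ-0.5B-Distilled-SFT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

prompt = "Explain what a text streamer does in one sentence."  # placeholder prompt
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# The streamer yields decoded text chunks as generate() produces tokens.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs in a background thread
# while the main thread iterates over the streamer.
generation_thread = threading.Thread(
    target=model.generate,
    kwargs=dict(
        inputs=inputs.input_ids,
        max_new_tokens=128,   # illustrative value
        streamer=streamer,
        do_sample=True,
        temperature=0.7,      # illustrative value
        top_p=0.9,            # illustrative value
        pad_token_id=tokenizer.eos_token_id,
    ),
)
generation_thread.start()

# Consume decoded chunks as they arrive, then wait for generation to finish.
for new_text in streamer:
    print(new_text, end="", flush=True)
generation_thread.join()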