ehristoforu committed on
Commit
0975580
·
verified ·
1 Parent(s): 08ff14d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -71,18 +71,19 @@ def generate(
71
  )
72
  conversation.append({"role": "user", "content": message})
73
 
74
- input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt")
75
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
76
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
77
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
78
- input_ids = input_ids.to(model.device)
79
- attention_mask = input_ids["attention_mask"]
 
80
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
81
  generate_kwargs = dict(
82
- {"input_ids": input_ids},
83
  streamer=streamer,
84
  max_new_tokens=max_new_tokens,
85
- eos_token_id=tokenizer.eos_token_id,
86
  pad_token_id=tokenizer.eos_token_id,
87
  attention_mask=attention_mask,
88
  do_sample=True,
 
71
  )
72
  conversation.append({"role": "user", "content": message})
73
 
74
+ formatted = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
75
+ inputs = tokenizer(formatted, return_tensors="pt", padding=True)
76
+ #if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
77
+ # input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
78
+ # gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
79
+ inputs = inputs.to(model.device)
80
+ attention_mask = inputs["attention_mask"]
81
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
82
  generate_kwargs = dict(
83
+ {"input_ids": inputs["input_ids"]},
84
  streamer=streamer,
85
  max_new_tokens=max_new_tokens,
86
+ #eos_token_id=tokenizer.eos_token_id,
87
  pad_token_id=tokenizer.eos_token_id,
88
  attention_mask=attention_mask,
89
  do_sample=True,