OzoneAsai commited on
Commit
6e8f3cd
1 Parent(s): 9ef3882

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import datetime

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import streamlit as st

st.title("Japanese Text Generation")

# NOTE(review): this loads a 3.6B-parameter model at module level, which
# Streamlit re-executes on every rerun; on a recent Streamlit, wrapping the
# load in @st.cache_resource would avoid repeated loads — confirm version.
tokenizer = AutoTokenizer.from_pretrained(
    "rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False
)
model = AutoModelForCausalLM.from_pretrained(
    "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
)

# Fix: the original module-level `logs = []` was re-created on every Streamlit
# rerun, so the history list was always empty when rendered. Persist it in
# st.session_state so entries survive across button clicks.
if "logs" not in st.session_state:
    st.session_state.logs = []


def generate_text(input_prompt: str) -> str:
    """Generate a sampled continuation of *input_prompt* and return it.

    Only the newly generated tokens are decoded (the echoed prompt tokens are
    sliced off), and the model's "<NL>" markers are converted to real newlines.
    """
    token_ids = tokenizer.encode(
        input_prompt, add_special_tokens=False, return_tensors="pt"
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to("cpu"),
            do_sample=True,
            max_new_tokens=128,
            temperature=0.7,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Skip the prompt prefix: decode from position len(prompt tokens) onward.
    generated_text = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
    return generated_text.replace("<NL>", "\n")


prompt = st.text_area("Enter the prompt:")

if st.button("Submit"):
    generated_output = generate_text(prompt)
    # Fix: record an actual timestamp — the original displayed the prompt
    # under the "Time:" heading because no time was ever captured.
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    st.session_state.logs.append((timestamp, prompt, generated_output))

for log_time, log_prompt, log_output in st.session_state.logs:
    # Fix: st.beta_container() was removed from Streamlit; st.container() is
    # the stable replacement.
    with st.container():
        st.write("---")
        st.subheader("Time: {}".format(log_time))
        st.write("**Input**: {}".format(log_prompt))
        st.write("**Output**: {}".format(log_output))