File size: 9,939 Bytes
7464087
 
 
 
 
0c6166f
290c9c9
 
 
7464087
cf5b288
7464087
 
 
 
ba9988f
4a4232a
ee5c3c3
0c6166f
 
0f9dd39
7464087
 
 
0c6166f
 
 
 
 
 
 
ee5c3c3
 
 
 
0c6166f
 
cf5b288
23694dc
0c6166f
7464087
0c6166f
8dad166
ba9988f
 
7464087
 
 
 
629f62e
7464087
 
8dad166
 
7464087
 
 
8dad166
7464087
 
629f62e
7464087
 
 
3e997b3
 
 
 
 
 
 
7464087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629f62e
7464087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7464087
629f62e
ee5c3c3
 
 
 
7464087
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7464087
629f62e
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
 
7464087
ee5c3c3
7464087
629f62e
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7464087
ee5c3c3
7b62f67
ee5c3c3
7b62f67
ee5c3c3
7b62f67
ee5c3c3
7b62f67
629f62e
ee5c3c3
 
 
7b62f67
ee5c3c3
7b62f67
629f62e
ee5c3c3
 
 
 
 
 
 
 
 
 
7b62f67
ee5c3c3
7b62f67
ee5c3c3
b18ec20
629f62e
 
b18ec20
8dad166
629f62e
7b62f67
 
7464087
 
 
 
 
 
 
 
 
 
 
 
0c6166f
7464087
0c6166f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import gradio as gr
import argparse
import os, gc, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
import torch
# from pynvml import *
# nvmlInit()
# gpu_h = nvmlDeviceGetHandleByIndex(0)
ctx_limit = 4096
desc = f'''链接:<a href='https://colab.research.google.com/drive/1J1gLMMMA8GbD9JuQt6OKmwCTl9mWU0bb?usp=sharing'>太慢了?用Colab自己部署吧</a> <br /> <a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a><a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a><a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a><a href="https://zhuanlan.zhihu.com/p/618011122" target="_blank" style="margin:0 0.5em">知乎教程</a>
'''

parser = argparse.ArgumentParser(prog = 'ChatGal RWKV')
parser.add_argument('--share',action='store_true') 
parser.add_argument("--world",type=bool, default=False)
parser.add_argument('--ckpt',type=str,default="rwkv-loramerge-0426-v2-4096-epoch11.pth") 
parser.add_argument('--model_path',type=str,default=None,help="local model path") 
parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "0.2*25-31"')
args = parser.parse_args()
os.environ["RWKV_JIT_ON"] = '1'

# from rwkv.model import RWKV
from rwkv_lora import RWKV
lora_kwargs = {
    "lora":args.lora,
    "lora_alpha":args.lora_alpha,
    "lora_layer_filter":args.lora_layer_filter
}
if args.model_path:
    model_path = args.model_path
else:
    model_path = hf_hub_download(repo_id="Synthia/ChatGalRWKV", filename=args.ckpt)
# if 'ON_COLAB' in os.environ and os.environ['ON_COLAB'] == '1':
if torch.cuda.is_available() and torch.cuda.device_count()>0:
    os.environ["RWKV_JIT_ON"] = '0'
    os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
    model = RWKV(model=model_path, strategy='cuda bf16',**lora_kwargs)
else:
    model = RWKV(model=model_path, strategy='cpu bf16',**lora_kwargs)
from utils import PIPELINE, PIPELINE_ARGS
tokenizer_file = "rwkv_vocab_v20230424" if args.world else "20B_tokenizer.json"
pipeline = PIPELINE(model, tokenizer_file)

def infer(
        ctx,
        token_count=10,
        temperature_a=1.0,
        temperature=0.7,
        top_p=1.0,
        top_k=50,
        typical_p=1.0,
        presencePenalty = 0.05,
        countPenalty = 0.05,
):
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p), top_k=int(top_k),typical_p=float(typical_p),
                     alpha_frequency = countPenalty,
                     alpha_presence = presencePenalty,
                     temperature_a=max(0.2, float(temperature_a)),
                     token_ban = [0], # ban the generation of some tokens
                     token_stop = []) # stop generation whenever you see any token here

    # ctx = ctx.strip().split('\n')
    # for c in range(len(ctx)):
    #     ctx[c] = ctx[c].strip().strip('\u3000').strip('\r')
    # ctx = list(filter(lambda c: c != '', ctx))
    # ctx = '\n' + ('\n'.join(ctx)).strip()
    # if ctx == '':
    #     ctx = '\n'

    # gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    # print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}',flush=True)
    
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    state = None
    for i in range(int(token_count)):
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        for n in args.token_ban:
            out[n] = -float('inf')
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p, temperature_a=args.temperature_a)
        if token in args.token_stop:
            break
        all_tokens += [token]
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1
        
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str
            out_last = i + 1
    gc.collect()
    torch.cuda.empty_cache()
    yield out_str

examples = [
    ["""女招待: 欢迎光临。您远道而来,想必一定很累了吧?

深见: 不会……空气也清爽,也让我焕然一新呢

女招待: 是吗。那真是太好了

{我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。}""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
    ["""{我叫嘉祥,家里经营着一家点心店。
为了追求独当一面的目标,我离开了老家,开了一家名为"La Soleil"的新糕点店。
原本想独自一人打拼,却没想到,在搬家的箱子中发现了意想不到的人。
她叫巧克力,是我家的猫娘,没想到她竟然用这种方式跟了过来。}

嘉祥: 别以为这样就可以蒙混过去!你在干嘛啊,巧克力!

巧克力: 欸嘿嘿……那个,好、好久不见了呢,主人……

嘉祥: 昨天才在家里见过面不是吗。

巧克力: 这个……话是这么说没错啦……啊哈哈……""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
    ["""莲华: 你的目的,就是这个万华镜吧?

{莲华拿出了万华镜。}

深见: 啊……

{好像被万华镜拽过去了一般,我的腿不由自主地向它迈去}

深见: 是这个……就是这个啊……

{烨烨生辉的魔法玩具。
连接现实与梦之世界的、诱惑的桥梁。}

深见: 请让我好好看看……

{我刚想把手伸过去,莲华就一下子把它收了回去。}""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
    ["""{我叫嘉祥,有两只可爱的猫娘,名字分别是巧克力和香草。}

嘉祥: 偶尔来一次也不错。

{我坐到客厅的沙发上,拍了拍自己的大腿。}

巧克力&香草: 喵喵?

巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪

巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪

香草: 身为猫娘饲主,这点服务也是应该的对吧?

香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪

{我摸摸各自占据住我左右两腿的两颗猫头。}

嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
    ["""{我叫嘉祥,在日本开了一家名为La Soleil的糕点店,同时也是猫娘巧克力的主人。
巧克力是非常聪明的猫娘,她去国外留学了一段时间,向Alice教授学习,拿到了计算机博士学位。
她会各种程序语言,对世界各地的风土人情都十分了解,也掌握了很多数学、物理知识。}

嘉祥: 很棒啊,巧克力!你真是懂不少东西呢!

巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
]

iface = gr.Interface(
    fn=infer,
    description=f'''这是GalGame剧本续写模型(实验性质,不保证效果)。<b>请点击例子(在页面底部)</b>,可编辑内容。这里只看输入的最后约1200字,请写好,标点规范,无错别字,否则电脑会模仿你的错误。<b>为避免占用资源,每次生成限制长度。可将输出内容复制到输入,然后继续生成</b>。推荐提高temp改善文采,降低topp改善逻辑,提高两个penalty避免重复,具体幅度请自己实验。<br /> {desc}''',
    allow_flagging="never",
    inputs=[
        gr.Textbox(lines=10, label="Prompt 输入的前文", value="""{我叫嘉祥,在日本开了一家名为La Soleil的糕点店,同时也是猫娘巧克力的主人。
巧克力是非常聪明的猫娘,她去国外留学了一段时间,向Alice教授学习,拿到了计算机博士学位。
她会各种程序语言,对世界各地的风土人情都十分了解,也掌握了很多数学、物理知识。}

嘉祥: 很棒啊,巧克力!你真是懂不少东西呢!

巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!"""),  # prompt
        gr.Slider(10, 2000, step=10, value=200, label="token_count 每次生成的长度"),  # token_count
        gr.Slider(0.2, 2.0, step=0.1, value=0.6, label="temperature_a 过滤前温度,高则变化丰富,低则保守求稳"),  # temperature_a
        gr.Slider(0.2, 2.0, step=0.1, value=1.2, label="temperature 过滤后温度,高则变化丰富,低则保守求稳"),  # temperature
        gr.Slider(0.0, 1.0, step=0.05, value=1.0, label="top_p 默认1.0,高则标新立异,低则循规蹈矩"),  # top_p
        gr.Slider(0, 500, step=1, value=0, label="top_k 默认0(不过滤),0以上时高则标新立异,低则循规蹈矩"),  # top_p
        gr.Slider(0.05, 1.0, step=0.05, value=0.4, label="typical_p 默认0.4,高则保留模型天性,低则试图贴近人类典型习惯"),  # top_p
        gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="presencePenalty 默认0.0,避免写过的类似字"),  # presencePenalty
        gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="countPenalty 默认0.0,额外避免写过多次的类似字"),  # countPenalty
    ],
    outputs=gr.Textbox(label="Output 输出的续写", lines=28),
    examples=examples,
    cache_examples=False,
).queue()

demo = gr.TabbedInterface(
    [iface], ["Generative"]
)

demo.queue(max_size=5)
if args.share:
    demo.launch(share=True,server_name="0.0.0.0",server_port=58888)
else:
    demo.launch(share=False,server_port=58888)