File size: 9,939 Bytes
7464087 0c6166f 290c9c9 7464087 cf5b288 7464087 ba9988f 4a4232a ee5c3c3 0c6166f 0f9dd39 7464087 0c6166f ee5c3c3 0c6166f cf5b288 23694dc 0c6166f 7464087 0c6166f 8dad166 ba9988f 7464087 629f62e 7464087 8dad166 7464087 8dad166 7464087 629f62e 7464087 3e997b3 7464087 629f62e 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 629f62e ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 629f62e ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 629f62e ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7464087 ee5c3c3 7b62f67 ee5c3c3 7b62f67 ee5c3c3 7b62f67 ee5c3c3 7b62f67 629f62e ee5c3c3 7b62f67 ee5c3c3 7b62f67 629f62e ee5c3c3 7b62f67 ee5c3c3 7b62f67 ee5c3c3 b18ec20 629f62e b18ec20 8dad166 629f62e 7b62f67 7464087 0c6166f 7464087 0c6166f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import gradio as gr
import argparse
import os, gc, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
import torch
# from pynvml import *
# nvmlInit()
# gpu_h = nvmlDeviceGetHandleByIndex(0)
ctx_limit = 4096
desc = f'''链接:<a href='https://colab.research.google.com/drive/1J1gLMMMA8GbD9JuQt6OKmwCTl9mWU0bb?usp=sharing'>太慢了?用Colab自己部署吧</a> <br /> <a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a><a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a><a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a><a href="https://zhuanlan.zhihu.com/p/618011122" target="_blank" style="margin:0 0.5em">知乎教程</a>
'''
parser = argparse.ArgumentParser(prog = 'ChatGal RWKV')
parser.add_argument('--share',action='store_true')
parser.add_argument("--world",type=bool, default=False)
parser.add_argument('--ckpt',type=str,default="rwkv-loramerge-0426-v2-4096-epoch11.pth")
parser.add_argument('--model_path',type=str,default=None,help="local model path")
parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "0.2*25-31"')
args = parser.parse_args()
os.environ["RWKV_JIT_ON"] = '1'
# from rwkv.model import RWKV
from rwkv_lora import RWKV
lora_kwargs = {
"lora":args.lora,
"lora_alpha":args.lora_alpha,
"lora_layer_filter":args.lora_layer_filter
}
if args.model_path:
model_path = args.model_path
else:
model_path = hf_hub_download(repo_id="Synthia/ChatGalRWKV", filename=args.ckpt)
# if 'ON_COLAB' in os.environ and os.environ['ON_COLAB'] == '1':
if torch.cuda.is_available() and torch.cuda.device_count()>0:
os.environ["RWKV_JIT_ON"] = '0'
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
model = RWKV(model=model_path, strategy='cuda bf16',**lora_kwargs)
else:
model = RWKV(model=model_path, strategy='cpu bf16',**lora_kwargs)
from utils import PIPELINE, PIPELINE_ARGS
tokenizer_file = "rwkv_vocab_v20230424" if args.world else "20B_tokenizer.json"
pipeline = PIPELINE(model, tokenizer_file)
def infer(
ctx,
token_count=10,
temperature_a=1.0,
temperature=0.7,
top_p=1.0,
top_k=50,
typical_p=1.0,
presencePenalty = 0.05,
countPenalty = 0.05,
):
args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p), top_k=int(top_k),typical_p=float(typical_p),
alpha_frequency = countPenalty,
alpha_presence = presencePenalty,
temperature_a=max(0.2, float(temperature_a)),
token_ban = [0], # ban the generation of some tokens
token_stop = []) # stop generation whenever you see any token here
# ctx = ctx.strip().split('\n')
# for c in range(len(ctx)):
# ctx[c] = ctx[c].strip().strip('\u3000').strip('\r')
# ctx = list(filter(lambda c: c != '', ctx))
# ctx = '\n' + ('\n'.join(ctx)).strip()
# if ctx == '':
# ctx = '\n'
# gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
# print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}',flush=True)
all_tokens = []
out_last = 0
out_str = ''
occurrence = {}
state = None
for i in range(int(token_count)):
out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
for n in args.token_ban:
out[n] = -float('inf')
for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p, temperature_a=args.temperature_a)
if token in args.token_stop:
break
all_tokens += [token]
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
tmp = pipeline.decode(all_tokens[out_last:])
if '\ufffd' not in tmp:
out_str += tmp
yield out_str
out_last = i + 1
gc.collect()
torch.cuda.empty_cache()
yield out_str
examples = [
["""女招待: 欢迎光临。您远道而来,想必一定很累了吧?
深见: 不会……空气也清爽,也让我焕然一新呢
女招待: 是吗。那真是太好了
{我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。}""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
["""{我叫嘉祥,家里经营着一家点心店。
为了追求独当一面的目标,我离开了老家,开了一家名为"La Soleil"的新糕点店。
原本想独自一人打拼,却没想到,在搬家的箱子中发现了意想不到的人。
她叫巧克力,是我家的猫娘,没想到她竟然用这种方式跟了过来。}
嘉祥: 别以为这样就可以蒙混过去!你在干嘛啊,巧克力!
巧克力: 欸嘿嘿……那个,好、好久不见了呢,主人……
嘉祥: 昨天才在家里见过面不是吗。
巧克力: 这个……话是这么说没错啦……啊哈哈……""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
["""莲华: 你的目的,就是这个万华镜吧?
{莲华拿出了万华镜。}
深见: 啊……
{好像被万华镜拽过去了一般,我的腿不由自主地向它迈去}
深见: 是这个……就是这个啊……
{烨烨生辉的魔法玩具。
连接现实与梦之世界的、诱惑的桥梁。}
深见: 请让我好好看看……
{我刚想把手伸过去,莲华就一下子把它收了回去。}""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
["""{我叫嘉祥,有两只可爱的猫娘,名字分别是巧克力和香草。}
嘉祥: 偶尔来一次也不错。
{我坐到客厅的沙发上,拍了拍自己的大腿。}
巧克力&香草: 喵喵?
巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪
巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪
香草: 身为猫娘饲主,这点服务也是应该的对吧?
香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪
{我摸摸各自占据住我左右两腿的两颗猫头。}
嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
["""{我叫嘉祥,在日本开了一家名为La Soleil的糕点店,同时也是猫娘巧克力的主人。
巧克力是非常聪明的猫娘,她去国外留学了一段时间,向Alice教授学习,拿到了计算机博士学位。
她会各种程序语言,对世界各地的风土人情都十分了解,也掌握了很多数学、物理知识。}
嘉祥: 很棒啊,巧克力!你真是懂不少东西呢!
巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
]
iface = gr.Interface(
fn=infer,
description=f'''这是GalGame剧本续写模型(实验性质,不保证效果)。<b>请点击例子(在页面底部)</b>,可编辑内容。这里只看输入的最后约1200字,请写好,标点规范,无错别字,否则电脑会模仿你的错误。<b>为避免占用资源,每次生成限制长度。可将输出内容复制到输入,然后继续生成</b>。推荐提高temp改善文采,降低topp改善逻辑,提高两个penalty避免重复,具体幅度请自己实验。<br /> {desc}''',
allow_flagging="never",
inputs=[
gr.Textbox(lines=10, label="Prompt 输入的前文", value="""{我叫嘉祥,在日本开了一家名为La Soleil的糕点店,同时也是猫娘巧克力的主人。
巧克力是非常聪明的猫娘,她去国外留学了一段时间,向Alice教授学习,拿到了计算机博士学位。
她会各种程序语言,对世界各地的风土人情都十分了解,也掌握了很多数学、物理知识。}
嘉祥: 很棒啊,巧克力!你真是懂不少东西呢!
巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!"""), # prompt
gr.Slider(10, 2000, step=10, value=200, label="token_count 每次生成的长度"), # token_count
gr.Slider(0.2, 2.0, step=0.1, value=0.6, label="temperature_a 过滤前温度,高则变化丰富,低则保守求稳"), # temperature_a
gr.Slider(0.2, 2.0, step=0.1, value=1.2, label="temperature 过滤后温度,高则变化丰富,低则保守求稳"), # temperature
gr.Slider(0.0, 1.0, step=0.05, value=1.0, label="top_p 默认1.0,高则标新立异,低则循规蹈矩"), # top_p
gr.Slider(0, 500, step=1, value=0, label="top_k 默认0(不过滤),0以上时高则标新立异,低则循规蹈矩"), # top_p
gr.Slider(0.05, 1.0, step=0.05, value=0.4, label="typical_p 默认0.4,高则保留模型天性,低则试图贴近人类典型习惯"), # top_p
gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="presencePenalty 默认0.0,避免写过的类似字"), # presencePenalty
gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="countPenalty 默认0.0,额外避免写过多次的类似字"), # countPenalty
],
outputs=gr.Textbox(label="Output 输出的续写", lines=28),
examples=examples,
cache_examples=False,
).queue()
demo = gr.TabbedInterface(
[iface], ["Generative"]
)
demo.queue(max_size=5)
if args.share:
demo.launch(share=True,server_name="0.0.0.0",server_port=58888)
else:
demo.launch(share=False,server_port=58888) |