wanicca committed
Commit 629f62e · 1 parent: 0f9dd39

Add a pre-filter temperature (temperature_a); fix some bugs when top_p=1

Files changed (2):
  1. app.py +11 -8
  2. utils.py +13 -8
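
In short: `app.py` adds a `temperature_a` slider and threads the value through `infer` into `PIPELINE_ARGS`, retuning the bundled example presets for the new defaults (`temperature_a=0.6`, `temperature=1.2`, `typical_p=0.4`); `utils.py` stores the new parameter, divides the logits by `temperature_a` in `sample_logits` before any filtering, and changes the top-p/typical-p cutoff comparisons so that `top_p=1.0` genuinely means "no filtering".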
app.py CHANGED

```diff
@@ -45,6 +45,7 @@ pipeline = PIPELINE(model, "20B_tokenizer.json")
 def infer(
         ctx,
         token_count=10,
+        temperature_a=1.0,
         temperature=0.7,
         top_p=1.0,
         top_k=50,
@@ -55,6 +56,7 @@ def infer(
     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p), top_k=int(top_k),typical_p=float(typical_p),
                      alpha_frequency = countPenalty,
                      alpha_presence = presencePenalty,
+                     temperature_a=max(0.2, float(temperature_a)),
                      token_ban = [0], # ban the generation of some tokens
                      token_stop = []) # stop generation whenever you see any token here
 
@@ -81,7 +83,7 @@ def infer(
         for n in occurrence:
             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
 
-        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p)
+        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p, temperature_a=args.temperature_a)
         if token in args.token_stop:
             break
         all_tokens += [token]
@@ -106,7 +108,7 @@ examples = [
 
 女招待: 是吗。那真是太好了
 
-{我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。}""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1],
+{我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。}""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
 ["""{我叫嘉祥,家里经营着一家点心店。
 为了追求独当一面的目标,我离开了老家,开了一家名为"La Soleil"的新糕点店。
 原本想独自一人打拼,却没想到,在搬家的箱子中发现了意想不到的人。
@@ -118,7 +120,7 @@ examples = [
 
 嘉祥: 昨天才在家里见过面不是吗。
 
-巧克力: 这个……话是这么说没错啦……啊哈哈……""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1],
+巧克力: 这个……话是这么说没错啦……啊哈哈……""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
 ["""莲华: 你的目的,就是这个万华镜吧?
 
 {莲华拿出了万华镜。}
@@ -134,7 +136,7 @@ examples = [
 
 深见: 请让我好好看看……
 
-{我刚想把手伸过去,莲华就一下子把它收了回去。}""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1],
+{我刚想把手伸过去,莲华就一下子把它收了回去。}""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
 ["""{我叫嘉祥,有两只可爱的猫娘,名字分别是巧克力和香草。}
 
 嘉祥: 偶尔来一次也不错。
@@ -153,14 +155,14 @@ examples = [
 
 {我摸摸各自占据住我左右两腿的两颗猫头。}
 
-嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1],
+嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
 ["""{我叫嘉祥,在日本开了一家名为La Soleil的糕点店,同时也是猫娘巧克力的主人。
 巧克力是非常聪明的猫娘,她去国外留学了一段时间,向Alice教授学习,拿到了计算机博士学位。
 她会各种程序语言,对世界各地的风土人情都十分了解,也掌握了很多数学、物理知识。}
 
 嘉祥: 很棒啊,巧克力!你真是懂不少东西呢!
 
-巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1],
+巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!""", 200, 0.6, 1.2, 1.0, 0, 0.4, 0.1, 0.1],
 ]
 
 iface = gr.Interface(
@@ -176,10 +178,11 @@ iface = gr.Interface(
 
 巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!"""), # prompt
         gr.Slider(10, 2000, step=10, value=200, label="token_count 每次生成的长度"), # token_count
-        gr.Slider(0.2, 2.0, step=0.1, value=0.7, label="temperature 默认0.7,高则变化丰富,低则保守求稳"), # temperature
+        gr.Slider(0.2, 2.0, step=0.1, value=0.6, label="temperature_a 过滤前温度,高则变化丰富,低则保守求稳"), # temperature_a
+        gr.Slider(0.2, 2.0, step=0.1, value=1.2, label="temperature 过滤后温度,高则变化丰富,低则保守求稳"), # temperature
         gr.Slider(0.0, 1.0, step=0.05, value=1.0, label="top_p 默认1.0,高则标新立异,低则循规蹈矩"), # top_p
         gr.Slider(0, 500, step=1, value=0, label="top_k 默认0(不过滤),0以上时高则标新立异,低则循规蹈矩"), # top_p
-        gr.Slider(0.05, 1.0, step=0.05, value=1.0, label="typical_p 默认1.0,高则保留模型天性,低则试图贴近人类典型习惯"), # top_p
+        gr.Slider(0.05, 1.0, step=0.05, value=0.4, label="typical_p 默认0.4,高则保留模型天性,低则试图贴近人类典型习惯"), # top_p
         gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="presencePenalty 默认0.0,避免写过的类似字"), # presencePenalty
         gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="countPenalty 默认0.0,额外避免写过多次的类似字"), # countPenalty
     ],
```
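
The "pre-filter temperature" named in the commit message is `temperature_a`: it divides the raw logits before softmax, so it reshapes the distribution that the top-p/top-k/typical filters see, while the existing `temperature` is applied only at the final draw, after filtering. Below is a minimal sketch of the two-stage scheme, not the repo's code: the helper name `two_stage_sample` is ours, and the post-filter step assumes the usual RWKV-pipeline convention of raising the surviving probabilities to `1/temperature` before `multinomial`, which this diff does not show.

```python
import torch
from torch.nn import functional as F

def two_stage_sample(logits: torch.Tensor, temperature_a: float = 1.0,
                     temperature: float = 1.0, top_p: float = 1.0) -> int:
    # Stage 1: the pre-filter temperature reshapes the distribution
    # *before* nucleus filtering -- this is what the new temperature_a does.
    if temperature_a != 1.0:
        logits = logits / temperature_a
    probs = F.softmax(logits.float(), dim=-1)

    # Nucleus (top-p) filtering on the reshaped distribution; note the
    # top_p < 1 guard, mirroring the commit's fix for top_p == 1.
    if top_p < 1:
        sorted_probs, _ = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # first index where the cumulative mass reaches top_p
        idx = torch.searchsorted(cumulative, torch.tensor(top_p))
        cutoff = float(sorted_probs[idx])
        probs[probs < cutoff] = 0

    # Stage 2: the post-filter temperature acts only on the survivors
    # (assumed convention: probs ** (1/temperature), then multinomial).
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)
    return int(torch.multinomial(probs / probs.sum(), num_samples=1))

if __name__ == "__main__":
    logits = torch.randn(50277)  # 20B_tokenizer vocab size, demo value only
    print(two_stage_sample(logits, temperature_a=0.6, temperature=1.2, top_p=1.0))
```

With the new example presets, `temperature_a=0.6` sharpens the distribution so the filters see a more confident model, and `temperature=1.2` then re-flattens whatever survives; note that `infer` clamps both values to at least 0.2 via `max(0.2, ...)`.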
utils.py CHANGED

```diff
@@ -4,13 +4,14 @@ import torch
 from torch.nn import functional as F
 
 class PIPELINE_ARGS():
-    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, typical_p=1, alpha_frequency=0.2, alpha_presence=0.2, token_ban=[], token_stop=[], chunk_len=256):
+    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, typical_p=1, alpha_frequency=0.2, alpha_presence=0.2, temperature_a=1.0, token_ban=[], token_stop=[], chunk_len=256):
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
         self.typical_p = typical_p
         self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
         self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
+        self.temperature_a = temperature_a
         self.token_ban = token_ban # ban the generation of some tokens
         self.token_stop = token_stop # stop generation whenever you see any token here
         self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
@@ -44,7 +45,9 @@ class PIPELINE():
     def decode(self, x):
         return self.tokenizer.decode(x)
 
-    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0,typical_p=1):
+    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0,typical_p=1,temperature_a=1.0):
+        if temperature_a != 1.0:
+            logits = logits / temperature_a
         probs = F.softmax(logits.float(), dim=-1)
         top_k = int(top_k)
         if typical_p<1:
@@ -54,14 +57,15 @@ class PIPELINE():
             sorted_typical_scores = typical_scores[typical_sorted_ids]
             typical_sorted_probs = probs[typical_sorted_ids]
             cum_typical_sorted_probs = torch.cumsum(typical_sorted_probs, dim=-1).cpu().numpy()
-            typical_cutoff = float(sorted_typical_scores[np.argmax(cum_typical_sorted_probs > typical_p)])
+            typical_cutoff = float(sorted_typical_scores[np.argmax(cum_typical_sorted_probs >= typical_p)])
         if probs.device == torch.device('cpu'):
             probs = probs.numpy()
             sorted_ids = np.argsort(probs)
             sorted_probs = probs[sorted_ids][::-1]
             cumulative_probs = np.cumsum(sorted_probs)
-            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
-            probs[probs < cutoff] = 0
+            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
+            if top_p < 1:
+                probs[probs < cutoff] = 0
             if top_k < len(probs) and top_k > 0:
                 probs[sorted_ids[:-top_k]] = 0
             if typical_p<1:
@@ -76,8 +80,9 @@ class PIPELINE():
             sorted_probs = probs[sorted_ids]
             sorted_probs = torch.flip(sorted_probs, dims=(0,))
             cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
-            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
-            probs[probs < cutoff] = 0
+            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
+            if top_p < 1:
+                probs[probs < cutoff] = 0
             if top_k < len(probs) and top_k > 0:
                 probs[sorted_ids[:-top_k]] = 0
             if typical_p<1:
@@ -106,7 +111,7 @@ class PIPELINE():
             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
 
         # sampler
-        token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p)
+        token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p,temperature_a=args.temperature_a)
         if token in args.token_stop:
             break
         all_tokens += [token]
```
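
The `>` → `>=` change plus the new `if top_p < 1` guard is the advertised top_p=1 fix. With the old strict comparison, `top_p=1.0` was meant to disable nucleus filtering, but `cumulative_probs > 1.0` is typically False at every index (the cumulative sum tops out around 1.0), so `np.argmax` of an all-False array returns 0, the cutoff became the single largest probability, and the mask zeroed every other token, silently turning top_p=1 into greedy decoding. A minimal NumPy reproduction with toy values (ours, not from the repo):

```python
import numpy as np

probs = np.array([0.5, 0.3, 0.2])   # toy distribution, already sorted descending
cumulative = np.cumsum(probs)       # [0.5, 0.8, 1.0] (sums to exactly 1.0 here)
top_p = 1.0                         # intended meaning: "no filtering"

# Old logic: 'cumulative > 1.0' is False everywhere, np.argmax of an
# all-False array returns 0, so the cutoff is the *largest* probability
# and masking zeroes every other token -- accidental greedy sampling.
old_cutoff = float(probs[np.argmax(cumulative > top_p)])   # 0.5

# Fixed logic: '>=' selects the index where the mass reaches top_p,
# and the new guard skips masking entirely unless top_p < 1.
new_cutoff = float(probs[np.argmax(cumulative >= top_p)])  # 0.2
masked = probs.copy()
if top_p < 1:
    masked[masked < new_cutoff] = 0
print(old_cutoff, new_cutoff, masked)  # 0.5 0.2 [0.5 0.3 0.2]
```

Even with `>=`, floating-point error can leave the cumulative sum just below 1.0, which is presumably why the masking is also gated on `top_p < 1`. The analogous typical-sampling comparison is updated to `>=` as well, mostly for consistency: that branch only runs when `typical_p < 1`, so the all-False failure mode could not occur there.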