thefcraft commited on
Commit
daa11fd
·
1 Parent(s): c4c2f8d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +68 -0
main.py CHANGED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import random
3
+ import numpy as np
4
+
5
+ with open('models.pickle', 'rb')as f:
6
+ models = pickle.load(f)
7
+
8
+ LORA_TOKEN = ''#'<|>LORA_TOKEN<|>'
9
+ # WEIGHT_TOKEN = '<|>WEIGHT_TOKEN<|>'
10
+ NOT_SPLIT_TOKEN = '<|>NOT_SPLIT_TOKEN<|>'
11
+
12
+ def sample_next(ctx:str,model,k):
13
+
14
+ ctx = ', '.join(ctx.split(', ')[-k:])
15
+ if model.get(ctx) is None:
16
+ return " "
17
+ possible_Chars = list(model[ctx].keys())
18
+ possible_values = list(model[ctx].values())
19
+
20
+ # print(possible_Chars)
21
+ # print(possible_values)
22
+
23
+ return np.random.choice(possible_Chars,p=possible_values)
24
+
25
+ def generateText(model, minLen=100, size=5):
26
+ keys = list(model.keys())
27
+ starting_sent = random.choice(keys)
28
+ k = len(random.choice(keys).split(', '))
29
+
30
+ sentence = starting_sent
31
+ ctx = ', '.join(starting_sent.split(', ')[-k:])
32
+
33
+ while True:
34
+ next_prediction = sample_next(ctx,model,k)
35
+ sentence += f", {next_prediction}"
36
+ ctx = ', '.join(sentence.split(', ')[-k:])
37
+ # if sentence.count('\n')>size: break
38
+ if '\n' in sentence: break
39
+ sentence = sentence.replace(NOT_SPLIT_TOKEN, ', ')
40
+ # sentence = re.sub(WEIGHT_TOKEN.replace('|', '\|'), lambda match: f":{random.randint(0,2)}.{random.randint(0,9)}", sentence)
41
+ # sentence = sentence.replace(":0.0", ':0.1')
42
+ # return sentence
43
+
44
+ prompt = sentence.split('\n')[0]
45
+ if len(prompt)<minLen:
46
+ prompt = generateText(model, minLen, size=1)[0]
47
+
48
+ size = size-1
49
+ if size == 0: return [prompt]
50
+ output = []
51
+ for i in range(size+1):
52
+ prompt = generateText(model, minLen, size=1)[0]
53
+ output.append(prompt)
54
+
55
+ return output
56
+ if __name__ == "__main__":
57
+ for model in models: # models = [(model, neg_model), (nsfw, neg_nsfw), (sfw, neg_sfw)]
58
+ text = generateText(model[0], k=k, minLen=300, size=5)
59
+ text_neg = generateText(model[1], k=k, minLen=300, size=5)
60
+
61
+ # print('\n'.join(text))
62
+ for i in range(len(text)):
63
+ print(text[i])
64
+ # print('negativePrompt:')
65
+ print(text_neg[i])
66
+ print('----------------------------------------------------------------')
67
+ print('********************************************************************************************************************************************************')
68
+