Spaces:
Sleeping
Sleeping
File size: 7,061 Bytes
013fb26 f42affe 013fb26 2dce12e 013fb26 ef797fb 2dce12e 013fb26 3b7d642 013fb26 2dce12e 013fb26 71d0e66 013fb26 2dce12e 013fb26 2dce12e 013fb26 2dce12e 979a861 2dce12e 013fb26 b013fb5 013fb26 2dce12e 013fb26 2dce12e d55c571 2dce12e 013fb26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import re
from unittest import result
import string
import streamlit as st
import torch
from torch.nn import functional as F
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoTokenizer,
GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
pipeline, top_k_top_p_filtering, PhrasalConstraint, DisjunctiveConstraint)
import ast
class ModifyLogitsProcessor(LogitsProcessor):
    """Ban or re-weight every vocabulary token that contains a given substring.

    For each key of ``chars_to_modify`` the ids of all vocabulary tokens whose
    text contains that key are collected once, at construction time. On every
    generation step the scores of those tokens are either set to ``-inf``
    (filter mode) or shifted by the factor associated with the key.

    Args:
        tokenizer: Object exposing ``get_vocab()``, a mapping of token string
            to token id.
        chars_to_modify: Mapping of substring -> additive logit factor. The
            factors are only read when ``filter_mode`` is false.
        filter_mode: When true, matching tokens are set to ``-inf`` instead of
            being shifted by their factor.
    """

    def __init__(self, tokenizer, chars_to_modify, filter_mode=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.filter_mode = filter_mode
        self.chars_to_modify = chars_to_modify

        # get_vocab() maps token string -> token id. The id must be read from
        # the mapping's values: the position of a key while iterating the dict
        # is not guaranteed to equal its id. Fetch the vocab once, not per key.
        vocab = self.tokenizer.get_vocab()

        # Compute the tokens to modify at initialization
        self.tokens_to_modify = {}
        for char in chars_to_modify:
            if not char:
                # "" is a substring of every token; matching it would touch
                # the whole vocabulary, so an empty entry selects nothing.
                self.tokens_to_modify[char] = []
                continue
            self.tokens_to_modify[char] = [
                token_id for token, token_id in vocab.items() if char in token
            ]

    def __call__(self, input_ids, scores):
        """Return ``scores`` with the precomputed token columns modified in place."""
        for char, tokens in self.tokens_to_modify.items():
            if not tokens:
                continue
            if self.filter_mode:
                scores[:, tokens] = -float('inf')
            else:
                # Fetch the corresponding factor from chars_to_modify dictionary
                scores[:, tokens] += self.chars_to_modify[char]
        return scores
# ---------------------------------------------------------------------------
# Page header and sidebar form. Every widget value read here is consumed by
# the generation code further down the script.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text Generation with Transformers")
st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)")

form = st.sidebar.form("choose_settings")

form.header("Model Settings")
model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value = "TheBloke/vicuna-7B-1.1-HF")
form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
percision = form.selectbox("What percision are we loading the model with?", ["8bit", "16bit", "32bit"], )
form.caption("The lower the percision, the less ram the model takes and the faster it runs, but the quality is reduced")

form.header("Token Level Constraint Settings")
form.subheader("Lipogram Constraint")
form.caption("Lipograms are compositions where a certain letter or certain letters of the alphabet are omitted or discouraged")
filter_mode = form.checkbox("Filter Mode?", value=False)
form.caption("Enabling filter mode sets all selected tokens probabilities to negative infinity")
naughty_strings_list = form.text_input('Enter letters or words to filter or modify the probabilities of (comma separated):', value = "that,e")
factor_input = form.text_input('Enter corresponding factors to add to the logits (comma separated, ignored if in filter mode):', value = "5,-99")

form.header("Sequence Level Constraint Settings")
form.header("Phrasal Constraint")
force_word = form.text_input("Enter a word or sentence that is guaranteed to appear in the output", value = "lipogram")
form.header("Disjunctive Constraint")
force_flexible_input = form.text_input('Enter a list of words or sentences that the model must include at least one item from (in Python list format)', '["constraint", "banana"]')

# Always bind force_flexible so later code can read it even when the input
# box is empty or cannot be parsed.
force_flexible = []
if force_flexible_input:
    try:
        # literal_eval only accepts Python literals, so no code is executed.
        force_flexible = ast.literal_eval(force_flexible_input)
    except Exception as e:
        st.write('Failed to parse the list. Please check your input.')
        st.write('Error:', e)
        force_flexible = []

if naughty_strings_list:
    # Strip surrounding whitespace so "that, e" targets "e" rather than " e".
    chars = [piece.strip() for piece in naughty_strings_list.split(',')]
    if filter_mode:
        # Factors are documented as ignored in filter mode: do not require
        # them to be present, numeric, or equal in count to the entries.
        factors = [0.0] * len(chars)
    else:
        try:
            factors = [float(piece) for piece in factor_input.split(',')]
        except ValueError as e:
            # Report bad input instead of crashing the app; the empty list
            # makes the later length check refuse to generate.
            st.write('Failed to parse the factors. Please enter comma separated numbers.')
            st.write('Error:', e)
            factors = []
    chars_to_modify = dict(zip(chars, factors))
else:
    chars = ""
    factors = []
    chars_to_modify = {}

generate_args = st.text_input('model.generate() arguments (in python dictionary format) ', '{"max_new_tokens": 50, "min_new_tokens": 50, "temperature": 2.0, "num_return_sequences": 1, "do_sample": False, "num_beams": 2, "repetition_penalty": 3.0}')
st.caption("For more details on what these settings mean and a complete list of all settings, see here: https://huggingface.co./blog/how-to-generate and https://huggingface.co./docs/transformers/main_classes/text_generation#transformers.GenerationConfig and https://huggingface.co./docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationMixin.generate")

custom_prompt = """
### Human: Write about how much you love constrained text generation techniques
### Assistant:
"""
sequence = st.text_area("Enter a custom prompt", value = custom_prompt)
form.form_submit_button("Generate some Constrained Text!")
def parse_generate_args(args_str):
    """Parse a ``key:value,key:value`` string into a dict of integer values.

    Entries are separated by commas and each entry is split on ``:``. An
    entry that does not consist of exactly one key and one value is skipped.
    Whitespace around keys is removed so ``"a: 1, b: 2"`` yields the keys
    ``"a"`` and ``"b"``; quote characters, if any, are kept as part of the key.

    Args:
        args_str: The comma separated ``key:value`` string.

    Returns:
        A dict mapping each key to its value converted with ``int``.

    Raises:
        ValueError: If a value cannot be converted by ``int``.
    """
    args_dict = {}
    for arg in args_str.split(','):
        parts = arg.split(':')
        if len(parts) != 2:
            continue
        key, value = parts
        args_dict[key.strip()] = int(value)
    return args_dict
@st.cache_resource
def _load_tokenizer_cached(name):
    """Load the slow (``use_fast=False``) tokenizer for ``name``, cached per name."""
    return AutoTokenizer.from_pretrained(name, use_fast = False)


def load_the_tokenizer():
    """Return the tokenizer for the model currently selected in the sidebar.

    The module-level ``model_name`` is passed to the cached helper as an
    argument so that it becomes part of the cache key; caching a zero-argument
    function would keep returning the first tokenizer ever loaded, even after
    the user picks a different model.
    """
    return _load_tokenizer_cached(model_name)
@st.cache_resource
def _load_model_cached(name, percision):
    """Load the causal LM ``name`` at the requested precision, cached per (name, precision)."""
    load_kwargs = {"device_map": 'auto', "load_in_8bit": False}
    if percision == "32bit":
        pass  # library default dtype, no quantization
    elif percision == "16bit":
        load_kwargs["torch_dtype"] = torch.float16
    else:
        # "8bit" and any unrecognised value fall back to 8-bit loading.
        load_kwargs["load_in_8bit"] = True
    return AutoModelForCausalLM.from_pretrained(name, **load_kwargs)


def load_the_model(percision):
    """Return the model currently selected in the sidebar at the given precision.

    Args:
        percision: One of ``"8bit"``, ``"16bit"`` or ``"32bit"``; any other
            value is treated like ``"8bit"``.

    The module-level ``model_name`` is forwarded to the cached helper so that
    it is part of the cache key; keying on the precision alone would keep
    serving the previously loaded model after the model name changes.
    """
    return _load_model_cached(model_name, percision)
# ---------------------------------------------------------------------------
# Validate the user input, then run constrained generation and display it.
# ---------------------------------------------------------------------------
generate_kwargs = None
if len(chars) != len(factors):
    st.write("Please ensure that the number of characters matches the number of factors.")
else:
    # Parse the generate() arguments before the (expensive) model load so a
    # typo is reported immediately instead of crashing the app afterwards.
    try:
        generate_kwargs = ast.literal_eval(generate_args)
    except Exception as e:
        st.write('Failed to parse the model.generate() arguments. Please check your input.')
        st.write('Error:', e)
        generate_kwargs = None
    else:
        if not isinstance(generate_kwargs, dict):
            st.write('The model.generate() arguments must be a Python dictionary.')
            generate_kwargs = None

if generate_kwargs is not None:
    model = load_the_model(percision)
    tokenizer = load_the_tokenizer()

    constraints = []
    if force_word:
        constraints.append(PhrasalConstraint(
            tokenizer(force_word, add_special_tokens=False).input_ids
        ))
    # force_flexible is only read when the input box is non-empty; it is an
    # empty list after a parse failure, and an empty or non-string list cannot
    # be turned into a DisjunctiveConstraint.
    if force_flexible_input and force_flexible:
        is_word_list = isinstance(force_flexible, (list, tuple)) and all(
            isinstance(word, str) and word for word in force_flexible
        )
        if is_word_list:
            constraints.append(DisjunctiveConstraint(
                tokenizer(list(force_flexible), add_special_tokens=False).input_ids
            ))
        else:
            st.write('The disjunctive constraint must be a Python list of non-empty strings; it was ignored.')

    logits_processor = LogitsProcessorList(
        [ModifyLogitsProcessor(tokenizer, chars_to_modify, filter_mode=bool(filter_mode))]
    )

    # Put the prompt on whatever device the model was placed on rather than
    # assuming a CUDA device is present.
    input_ids = tokenizer.encode(sequence, return_tensors="pt").to(model.device)

    call_kwargs = dict(generate_kwargs)
    if constraints:
        call_kwargs["constraints"] = constraints
    output_ids = model.generate(input_ids, logits_processor=logits_processor, **call_kwargs)

    st.write("GENERATED SEQUENCE(s): ")
    for output in output_ids:
        st.write(tokenizer.decode(output, skip_special_tokens = True, clean_up_tokenization_spaces = True))
|