arnocandel committed • Commit 24b4b28 • 1 Parent(s): b43c18e
Update with h2oGPT hash e35e6ce906c57495ee80b1e3b8507ad374f6a50d
Files changed:
- finetune.py +20 -5
- generate.py +51 -6
- gradio_runner.py +3 -2
- requirements.txt +3 -3
finetune.py
CHANGED
@@ -30,6 +30,7 @@ class PromptType(Enum):
     human_bot_orig = 9
     prompt_answer = 10
     open_assistant = 11
+    wizard_lm = 12


 prompt_type_to_model_name = {
@@ -56,6 +57,8 @@ prompt_type_to_model_name = {
         'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
         'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
         'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
     ],
     'instruct': [],
     'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -63,15 +66,18 @@ prompt_type_to_model_name = {
     'human_bot': [
         'h2oai/h2ogpt-oasst1-512-12b',
         'h2oai/h2ogpt-oasst1-512-20b',
+        'h2oai/h2ogpt-oig-oasst1-512-20b',
+        'h2oai/h2ogpt-oig-oasst1-512-12b',
         'h2oai/h2ogpt-oig-oasst1-512-6.9b',
         'h2oai/h2ogpt-research-oasst1-512-30b',  # private
     ],
     'dai_faq': [],
     'summarize': [],
     'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
-    'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
+    'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
     'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
     "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
+    "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
 }

 inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
@@ -222,8 +228,6 @@ def train(
     NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:

     CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
-        from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
-        replace_llama_attn_with_flash_attn()
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
@@ -590,8 +594,8 @@ def train(
         tokenizer=tokenizer,
         train_dataset=train_data,
         eval_dataset=valid_data,
-        #
-        args=transformers.
+        # FIXME: might need Seq2SeqTrainingArguments for some models
+        args=transformers.TrainingArguments(
             per_device_train_batch_size=micro_batch_size,
             per_device_eval_batch_size=1,
             eval_accumulation_steps=10,
@@ -901,6 +905,17 @@ Current Time: {}
         eos = "</s>"
         terminate_response = [start, PreResponse, pend, eos]
         chat_sep = eos
+    elif prompt_type in [12, "12", "wizard_lm"]:
+        # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
+        preprompt = ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = ""
+        PreInput = None
+        PreResponse = "\n\n### Response"
+        eos = "</s>"
+        terminate_response = [PreResponse, eos]
+        chat_sep = eos
     else:
         raise RuntimeError("No such prompt_type=%s" % prompt_type)
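Note on the new wizard_lm prompt type: the last hunk above only defines the prompt fragments (promptA, PreInstruct, PreResponse, eos). The sketch below shows how those fragments would be stitched into a full prompt for a WizardLM-style model; render_wizard_lm_prompt() is a hypothetical helper for illustration only, not code from finetune.py.

# Minimal sketch using the wizard_lm fragments from the hunk above.
# render_wizard_lm_prompt() is hypothetical, not part of finetune.py.
def render_wizard_lm_prompt(instruction: str) -> str:
    preprompt = ''
    start = ''
    promptA = '%s%s' % (preprompt, start)   # empty prefix for wizard_lm
    PreInstruct = ""                        # no instruction marker
    PreResponse = "\n\n### Response"        # response marker the model expects
    # Generation would later be cut at PreResponse/eos via terminate_response.
    return promptA + PreInstruct + instruction + PreResponse


print(render_wizard_lm_prompt("Write a haiku about GPUs."))
# -> "Write a haiku about GPUs.\n\n### Response"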
generate.py
CHANGED
@@ -84,6 +84,7 @@ def main(
         api_open: bool = False,
         allow_api: bool = True,
         input_lines: int = 1,
+        auth: typing.List[typing.Tuple[str, str]] = None,

         sanitize_user_prompt: bool = True,
         sanitize_bot_response: bool = True,
@@ -145,6 +146,8 @@ def main(
     :param api_open: If False, don't let API calls skip gradio queue
     :param allow_api: whether to allow API calls at all to gradio server
     :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
+    :param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
+                 e.g. --auth=[('jon','password')] with no spaces
     :param sanitize_user_prompt: whether to remove profanity from user input
     :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output
     :param extra_model_options: extra models to show in list in gradio
@@ -211,7 +214,7 @@ def main(
     if psutil.virtual_memory().available < 94*1024**3:
         # 12B uses ~94GB
         # 6.9B uses ~47GB
-        base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
+        base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b' if not base_model else base_model

     # get defaults
     model_lower = base_model.lower()
@@ -881,13 +884,17 @@ def evaluate(
     else:
         gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))

+    decoder_kwargs = dict(skip_special_tokens=True,
+                          clean_up_tokenization_spaces=True)
+
     decoder = functools.partial(tokenizer.decode,
-                                skip_special_tokens=True,
-                                clean_up_tokenization_spaces=True,
+                                **decoder_kwargs
                                 )
+    decoder_raw_kwargs = dict(skip_special_tokens=False,
+                              clean_up_tokenization_spaces=True)
+
     decoder_raw = functools.partial(tokenizer.decode,
-                                    skip_special_tokens=False,
-                                    clean_up_tokenization_spaces=True,
+                                    **decoder_raw_kwargs
                                     )

     with torch.no_grad():
@@ -915,14 +922,16 @@ def evaluate(
         # some models specify special tokens that are part of normal prompt, so can't skip them
         inputs_decoded = prompt = inputs_decoded_raw
         decoder = decoder_raw
+        decoder_kwargs = decoder_raw_kwargs
     elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ', '') == prompt.replace('\n', ' ').replace(' ', ''):
         inputs_decoded = prompt = inputs_decoded_raw
         decoder = decoder_raw
+        decoder_kwargs = decoder_raw_kwargs
     else:
         print("WARNING: Special characters in prompt", flush=True)
     if stream_output:
         skip_prompt = False
-        streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False)
+        streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
         gen_kwargs.update(dict(streamer=streamer))
         target_func = generate_with_exceptions
         target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
@@ -1312,3 +1321,39 @@ if __name__ == "__main__":
     python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
     """
     fire.Fire(main)
+
+
+import pytest
+
+@pytest.mark.parametrize(
+    "base_model",
+    [
+        "h2oai/h2ogpt-oig-oasst1-512-6.9b",
+        "h2oai/h2ogpt-oig-oasst1-512-12b",
+        "h2oai/h2ogpt-oig-oasst1-512-20b",
+        "h2oai/h2ogpt-oasst1-512-12b",
+        "h2oai/h2ogpt-oasst1-512-20b",
+        "h2oai/h2ogpt-gm-oasst1-en-1024-20b",
+        "databricks/dolly-v2-12b",
+        "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2",
+        "ehartford/WizardLM-7B-Uncensored",
+        "ehartford/WizardLM-13B-Uncensored",
+        "AlekseyKorshuk/vicuna-7b",
+        "TheBloke/stable-vicuna-13B-HF",
+        "decapoda-research/llama-7b-hf",
+        "decapoda-research/llama-13b-hf",
+        "decapoda-research/llama-30b-hf",
+        "junelee/wizard-vicuna-13b",
+    ]
+)
+def test_score_eval(base_model):
+    main(
+        base_model=base_model,
+        chat=False,
+        stream_output=False,
+        gradio=False,
+        eval_sharegpt_prompts_only=500,
+        eval_sharegpt_as_output=False,
+        num_beams=2,
+        infer_devices=False,
+    )
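Note on the new auth option in generate.py: the sketch below shows how the server might be launched programmatically with it, following the form documented in the docstring ([(user1, pass1), ...]); the specific argument values are illustrative only, not taken from the repo. The parametrized test_score_eval added at the bottom of the file would presumably be collected and run by pytest (selected by name), scoring 500 ShareGPT prompts per listed model.

# Illustrative only: calling generate.main() with the new auth parameter.
from generate import main

main(
    base_model='h2oai/h2ogpt-oig-oasst1-512-6.9b',
    chat=True,
    stream_output=True,
    gradio=True,
    auth=[('jon', 'password')],  # [(user1, pass1), ...] as documented above
)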
gradio_runner.py
CHANGED
@@ -50,7 +50,7 @@ def go_gradio(**kwargs):
         """
     else:
         description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
-        description += "If this host is busy, try [
+        description += "If this host is busy, try [12B](https://gpt.h2o.ai), [30B](http://gpt2.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
         description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""

     if kwargs['verbose']:
@@ -921,7 +921,8 @@ def go_gradio(**kwargs):
     scheduler.start()

     demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
-                favicon_path=favicon_path, prevent_thread_lock=True
+                favicon_path=favicon_path, prevent_thread_lock=True,
+                auth=kwargs['auth'])
     print("Started GUI", flush=True)
     if kwargs['block_gradio_exit']:
         demo.block_thread()
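Note on the gradio_runner.py change: auth=kwargs['auth'] hands the credential list straight to gradio's built-in login. A small standalone sketch of that behavior follows, using a made-up echo app that is not part of gradio_runner.py.

# Standalone sketch of gradio's auth handling (hypothetical echo app,
# not part of gradio_runner.py). With auth set, gradio shows a login
# form before serving the UI.
import gradio as gr

def echo(message: str) -> str:
    return message

demo = gr.Interface(fn=echo, inputs="text", outputs="text")
demo.launch(server_name="0.0.0.0", show_error=True,
            prevent_thread_lock=True,
            auth=[("jon", "password")])  # list of (username, password) tuples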
requirements.txt
CHANGED
@@ -1,13 +1,13 @@
 # for generate (gradio server) and finetune
-datasets==2.
+datasets==2.12.0
 sentencepiece==0.1.97
 accelerate==0.18.0
 gradio==3.27.0
-huggingface_hub==0.
+huggingface_hub==0.14.1
 appdirs==1.4.4
 fire==0.5.0
 docutils==0.19
-torch==2.0.
+torch==2.0.1
 evaluate==0.4.0
 rouge_score==0.1.2
 sacrebleu==2.3.1