pseudotensor committed • Commit 30e5d19 • 1 Parent(s): 9bcca78

Update with h2oGPT hash 03227623260f552fd7e2b8c51409308bc7242933

Browse files:
- client_test.py     +42  -19
- create_data.py     +60  -69
- finetune.py        +7   -11
- generate.py        +236 -242
- gpt4all_llm.py     +162 -26
- gpt_langchain.py   +561 -183
- gradio_runner.py   +252 -110
- gradio_themes.py   +41  -2
- h2oai_pipeline.py  +96  -22
- prompter.py        +119 -22
- requirements.txt   +12  -11
- stopping.py        +6   -4
- utils.py           +83  -8
client_test.py
CHANGED
@@ -23,7 +23,7 @@ HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
 Result:

 Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
-{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.'}
+{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}

 For demo:

@@ -33,9 +33,15 @@ HOST="https://gpt.h2o.ai" python client_test.py
 Result:

 Loaded as API: https://gpt.h2o.ai ✔
-{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.'}
+{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
+
+NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict:
+
+{'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}

 """
+import ast
 import time
 import os
 import markdown  # pip install markdown

@@ -56,7 +62,7 @@ def get_client(serialize=True):
     return client


-def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50):
+def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50, langchain_mode='Disabled'):
     from collections import OrderedDict
     kwargs = OrderedDict(instruction=prompt if chat else '',  # only for chat=True
                          iinput='',  # only for chat=True

@@ -79,12 +85,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
                          chat=chat,
                          instruction_nochat=prompt if not chat else '',
                          iinput_nochat='',  # only for chat=False
-                         langchain_mode=
+                         langchain_mode=langchain_mode,
+                         top_k_docs=4,
                          document_choice=['All'],
                          )
     if chat:
         # add chatbot output on end. Assumes serialize=False
-        kwargs.update(dict(chatbot=[
+        kwargs.update(dict(chatbot=[]))

     return kwargs, list(kwargs.values())

@@ -103,22 +110,29 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
         *tuple(args),
         api_name=api_name,
     )
+    print("Raw client result: %s" % res, flush=True)
     res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
-                    response=md_to_text(res))
+                    response=md_to_text(ast.literal_eval(res)['response']),
+                    sources=ast.literal_eval(res)['sources'])
     print(res_dict)
     return res_dict


 @pytest.mark.skip(reason="For manual use against some server, no server launched")
 def test_client_chat():
-    return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50
+    return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50,
+                           langchain_mode='Disabled')


-def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens):
-    kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output, max_new_tokens=max_new_tokens)
+def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
     client = get_client(serialize=False)

+    kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
+                            max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
+    return run_client(client, prompt, args, kwargs)
+
+
+def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
     res = client.predict(*tuple(args), api_name='/instruction')
     args[-1] += [res[-1]]

@@ -127,8 +141,8 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens):
     if not kwargs['stream_output']:
         res = client.predict(*tuple(args), api_name='/instruction_bot')
         res_dict['response'] = res[0][-1][1]
-        print(md_to_text(res_dict['response']))
-        return res_dict
+        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
+        return res_dict, client
     else:
         job = client.submit(*tuple(args), api_name='/instruction_bot')
         res1 = ''

@@ -137,15 +151,24 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens):
             if outputs_list:
                 res = job.communicator.job.outputs[-1]
                 res1 = res[0][-1][-1]
-                res1 = md_to_text(res1)
+                res1 = md_to_text(res1, do_md_to_text=do_md_to_text)
                 print(res1)
             time.sleep(0.1)
+        full_outputs = job.outputs()
+        if verbose:
+            print('job.outputs: %s' % str(full_outputs))
+        # ensure get ending to avoid race
+        # -1 means last response if streaming
+        # 0 means get text_output, ignore exception_text
+        # 0 means get list within text_output that looks like [[prompt], [answer]]
+        # 1 means get bot answer, so will have last bot answer
+        res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text)
+        return res_dict, client
+
+
+def md_to_text(md, do_md_to_text=True):
+    if not do_md_to_text:
+        return md
     assert md is not None, "Markdown is None"
     html = markdown.markdown(md)
     soup = BeautifulSoup(html, features='html.parser')
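The nochat API change above means the endpoint now returns a stringified Python dict with 'response' and 'sources' keys rather than a bare string. Below is a minimal, standalone sketch of parsing such a payload, following the ast.literal_eval pattern in the diff; the example raw string is an illustrative assumption, not actual server output.

# Sketch: parse the stringified-dict payload returned by the nochat endpoint.
import ast

import markdown  # pip install markdown
from bs4 import BeautifulSoup  # pip install beautifulsoup4


def parse_nochat_result(res: str, do_md_to_text: bool = True) -> dict:
    payload = ast.literal_eval(res)  # server sends a repr() of a python dict, not JSON
    response = payload['response']
    if do_md_to_text:
        html = markdown.markdown(response)
        response = BeautifulSoup(html, features='html.parser').get_text()
    return dict(response=response, sources=payload.get('sources', ''))


if __name__ == '__main__':
    raw = "{'response': \"I'm h2oGPT.\", 'sources': ''}"  # hypothetical payload
    print(parse_nochat_result(raw))

Because the payload is a Python repr rather than JSON, ast.literal_eval is the safe way to decode it without executing arbitrary code.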
create_data.py
CHANGED
@@ -23,7 +23,7 @@ import pandas as pd
 import numpy as np
 from tqdm import tqdm

-from utils import flatten_list
+from utils import flatten_list, remove


 def parse_rst_file(filepath):
@@ -184,7 +184,7 @@ def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
     return dst


-def rst_to_outputs(files, min_len=30, max_len=2048//2 - 30):
+def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30):
     # account for sequence length (context window) including prompt and input and output
     # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
@@ -274,22 +274,6 @@ def test_scrape_dai_docs_all_pandoc():
         f.write(json.dumps(save_thing, indent=2))


-def remove(path: str):
-    try:
-        if path is not None and os.path.exists(path):
-            if os.path.isdir(path):
-                shutil_rmtree(path, ignore_errors=True)
-            else:
-                with contextlib.suppress(FileNotFoundError):
-                    os.remove(path)
-    except:
-        pass
-
-
-def shutil_rmtree(*args, **kwargs):
-    return shutil.rmtree(*args, **kwargs)
-
-
 def test_config_to_json():
     """
     Needs to run from Driverless AI source directory.
@@ -310,15 +294,18 @@ def test_config_to_json():
     [
         {
             'prompt_type': 'plain',
+            'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
+                "\n", ""),
         },
         {
             'prompt_type': 'plain',
+            'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
+                "\n", ""),
         },
         {
             'prompt_type': 'plain',
+            'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace(
+                "\n", ""),
         } if title and comment else None,
         {
             'prompt_type': 'human_bot',
@@ -420,7 +407,8 @@ def test_prep_instruct_vicuna():
     from datasets import load_dataset
     filename = 'ShareGPT_unfiltered_cleaned_split.json'
     if not os.path.exists(filename):
-        os.system(
+        os.system(
+            'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
     data = load_dataset("json", data_files={"train": filename})["train"]
     training_rows = []
     for i in range(data.num_rows):
@@ -440,6 +428,7 @@ def test_prep_instruct_vicuna():
     with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
         f.write(json.dumps(training_rows, indent=2))

+
 POSTFIX = ".generate_human_bot.train_plain.json"

 # https://bair.berkeley.edu/blog/2023/04/03/koala/
@@ -497,10 +486,10 @@ useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
                     'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
                     'unified_merged_code_xp3.jsonl.parquet',
                     'unified_multi_news.jsonl.parquet',
-                    #'unified_multi_sum.jsonl.parquet'
+                    # 'unified_multi_sum.jsonl.parquet'
                     'unified_ni.jsonl.gz.parquet',
                     'unified_openai_summarize_tldr.jsonl.parquet',
-                    #'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
+                    # 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
                     'unified_plot_screenplay_books_dialog.jsonl.parquet',
                     'unified_soda_dialog.jsonl.parquet',
                     'unified_unnatural_instructions.jsonl.parquet',
@@ -546,8 +535,8 @@ def test_merge_shuffle_small_sample_oig_data():

 def test_join_jsons():
     files = ['config.json'] * 1 + \
+            ['dai_docs.train_cleaned.json'] * 2 + \
+            ['dai_faq.json'] * 3
     print(files)
     lst = []
     [lst.extend(json.load(open(fil, 'rt'))) for fil in files]
@@ -570,11 +559,10 @@ def test_make_rlhf_good_data(filename):
     f.write(json.dumps(new_rows, indent=2))


 def test_show_prompts():
     files = ['config.json'] * 1 + \
+            ['dai_docs.train_cleaned.json'] * 1 + \
+            ['dai_faq.json'] * 1
     file_points = [json.load(open(fil, 'rt')) for fil in files]
     from prompter import generate_prompt
     for data_points in file_points:
@@ -600,7 +588,7 @@ def test_get_open_datasets():
     'license:openrail++',
     'license:openrail',
     'license:bigscience-bloom-rail-1.0',
-    #'license:agpl-3.0',
+    # 'license:agpl-3.0',
     'license:other',
     'license:unknown',
     # 'license:mpl-2.0',  # ok, but would have to include original copyright, license, source, copies in distribution
@@ -610,13 +598,13 @@ def test_get_open_datasets():
     'license:cc-by-3.0',
     'license:cc-by-2.0',
     'license:cc-by-2.5',
-    #'license:cc-by-sa-4.0', # would require same license
+    # 'license:cc-by-sa-4.0',  # would require same license
     'license:odbl',
     'license:pddl',
     'license:ms-pl',
     'license:zlib',
     ]
+    # bad license: cc-by-nc-4.0

     from huggingface_hub import list_datasets
     datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
@@ -656,12 +644,12 @@ def test_get_open_datasets():
                                    'language:' not in str(x.tags) or
                                    'language:en' in str(x.tags)]
     small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
+                                          'n<1K' in str(x.tags) or
+                                          '1K<n<10K' in str(x.tags) or
+                                          '1K0<n<100K' in str(x.tags) or
+                                          '100K<n<1M' in str(x.tags) or
+                                          'size_category' not in str(x.tags)
+                                          ]
     # 'aeslc' : email_body, subject -> summarization?
     # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
     ids = [x.id for x in small_open_english_tasked_datasets]
@@ -689,7 +677,8 @@ def test_get_open_datasets():
     'humarin/chatgpt-paraphrases',  # Paraphrase using ChatGPT
     'Jeska/vaccinchat',  # not useful
     'alespalla/chatbot_instruction_prompts',  # mixes alpaca
-    'allenai/prosocial-dialog',
+    'allenai/prosocial-dialog',
+    # already exlucded, but wrongly in other datasets that say more permissive license
     'AlekseyKorshuk/persona-chat',  # low quality
     'bavard/personachat_truecased',  # low quality
     'adamlin/daily_dialog',  # medium quality conversations
@@ -724,7 +713,8 @@ def test_get_open_datasets():
     # some ids clearly speech related
     small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
     # HF testing
-    small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
+    small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
+                                          'hf-internal-testing' not in x.id]
     small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
                                           'chinese' not in x.id]
@@ -738,7 +728,6 @@ def test_get_open_datasets():
     # grep "pip install" getdata9.log
     # NOTE: Some datasets have default config, but others are there. Don't know how to access them.
-
     """
     https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
     https://github.com/mahnazkoupaee/WikiHow-Dataset
@@ -773,7 +762,7 @@ def test_get_open_datasets():
 def do_one(data_id, num_downloads):
     from datasets import load_dataset
     out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
-    if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024**3:
+    if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024 ** 3:
         return
     try:
         print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
@@ -881,23 +870,21 @@ useful = ['Dahoas/instruct-human-assistant-prompt',
           'lmqg/qg_squad',  # context QA
           'lmqg/qg_squadshifts',  # context QA
           'lmqg/qg_subjqa',  # context QA
-          'pszemraj/HC3-textgen-qa',
+          'pszemraj/HC3-textgen-qa',
+          # QA medium, has human responses -- humans tend to provide links instead of trying to answer
          'pythonist/newdata',  # long context, QA, brief A
          'ropes',  # long background, situation, question, A
          'wikitablequestions',  # table -> QA
          'bigscience/p3',  # context QA but short answers
          ]

-
-
 code_useful = ['0n1xus/codexglue',
                'openai_humaneval',
                'koutch/staqc',
                ]

-
 maybe_useful = ['AlekseyKorshuk/comedy-scripts',
+                'openbookqa',  # hard to parse, low reasoning
                 'qed',  # reasonable QA, but low reasoning
                 'selqa',  # candidate answers
                 'HuggingFaceH4/instruction-pilot-outputs-filtered',
@@ -905,7 +892,6 @@ maybe_useful = ['AlekseyKorshuk/comedy-scripts',
                 'npc-engine/light-batch-summarize-dialogue',  # dialog summarize, kinda low specific quality
                 ]

-
 summary_useful = ['austin/rheum_abstracts',
                   'CarperAI/openai_summarize_comparisons',  # summarize chosen/rejected
                   'CarperAI/openai_summarize_tldr',  # summarize QA
@@ -928,14 +914,12 @@ summary_useful = ['austin/rheum_abstracts',
                   'stacked-summaries/stacked-xsum-1024',
                   ]

-
 math_useful = [
+    'competition_math'
+    ]

 skipped = ['c4',  # maybe useful, used for flan, but skipped due to size
+           ]

 """
 To get training data from oig:
@@ -958,14 +942,14 @@ def test_assemble_and_detox():
     text_list = df[['text']].values.ravel().tolist()
     new_text = []
     max_len = 2048  # uber cutoff
-    MAX_LEN = 2048//2 - 30  # max len per question/answer
+    MAX_LEN = 2048 // 2 - 30  # max len per question/answer
     for text in tqdm(text_list):
         human_starts = [m.start() for m in re.finditer('<human>: ', text)]
         if len(human_starts) == 1:
             human_starts = [0, len(text)]  # always go into for loop below
         blurb = ''
         for i in range(len(human_starts) - 1):
-            interaction = text[human_starts[i]: human_starts[i+1]][:max_len]
+            interaction = text[human_starts[i]: human_starts[i + 1]][:max_len]
             blurb += interaction
             if len(blurb) >= MAX_LEN:
                 blurb = get_sentences(blurb, length=MAX_LEN)[0]
@@ -1002,17 +986,17 @@ def test_basic_cleaning():
     from profanity_check import predict
     df_list = []
     for data in useful_oig_files:
+    # for data in useful_oig_files[:5]:
+    # for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
         print("Processing %s" % data, flush=True)
         df = pd.read_parquet(data)
         df = df.reset_index(drop=True)
         # NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired
-        #avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
-        df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot))/2.0)
+        # avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
+        df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot)) / 2.0)
         df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
-        #df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
-        #low_quality_patterns = ['Write the rest of this wikipedia article']
+        # df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
+        # low_quality_patterns = ['Write the rest of this wikipedia article']
         res = predict(df['text'])
         df['bad_words'] = res
         df = df.reset_index(drop=True)
@@ -1215,7 +1199,7 @@ def count_human_bot_lengths(df, human=None, bot=None):
         assert len(text)
         list_what = []
         for ii in range(len(starts) - 1):
-            interaction = text[starts[ii]: starts[ii+1]]
+            interaction = text[starts[ii]: starts[ii + 1]]
             if other in interaction:
                 interaction = interaction[:interaction.find(other)]
             interaction.strip()
@@ -1416,9 +1400,13 @@ def test_add_open_assistant(fixup_personality, only_personality, deberta_grading
     conv2['message_id'] = None
     conversations = [c for c in conversations if c['message_id']]
     if only_personality:
-        all_rows.extend(
+        all_rows.extend(
+            [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
+             'h2oGPT' in c['text']])
     else:
-        all_rows.extend(
+        all_rows.extend(
+            [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
+             "What is H2O.ai" not in c['text']])
     unhelpful = get_unhelpful_list()
     all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
     personality = create_personality_data()
@@ -1484,6 +1472,7 @@ def test_finalize_to_json():
             n_jobs=-1,
         )
         return df[(df['profanity'] == 0)].reset_index(drop=True)
+
     print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
     df = final_clean(df)
     print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
@@ -1721,7 +1710,7 @@ def test_check_unhelpful():
     # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'

     unhelpful = get_unhelpful_list()
-    #data = json.load(open(file, 'rt'))
+    # data = json.load(open(file, 'rt'))
     df = pd.read_json(file)

     use_reward_score_threshold = False
@@ -1733,7 +1722,7 @@ def test_check_unhelpful():
     from nltk.translate.bleu_score import sentence_bleu

     def get_bleu(actual, expected_list):
-        #return bleu.sentence_score(actual, expected_list).score
+        # return bleu.sentence_score(actual, expected_list).score
         return sentence_bleu(expected_list, actual)

     threshold = 0.0
@@ -1770,12 +1759,13 @@ def test_check_unhelpful():
     # pip install sentence_transformers-2.2.2
     from sentence_transformers import SentenceTransformer
     # sent_model = 'bert-base-nli-mean-tokens'
-    #sent_model = 'nli-distilroberta-base-v2'
+    # sent_model = 'nli-distilroberta-base-v2'
     sent_model = 'all-MiniLM-L6-v2'
     model = SentenceTransformer(sent_model)
     sentence_embeddings = model.encode(unhelpful)
     from sklearn.metrics.pairwise import cosine_similarity
-    bots = [x for x in tqdm(bots) if
+    bots = [x for x in tqdm(bots) if
+            np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]

     bads_bots = {}
     string_all = str(bots)
@@ -1787,7 +1777,8 @@ def test_check_unhelpful():
     pp.pprint(bads_bots)

     total_bads_bots = sum(list(bads_bots.values()))
-    print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (
+    print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (
+        threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)

     # assert len(bads) == 0, bads
     assert len(bads_bots) == 0, bads_bots
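The last hunks in create_data.py tighten the embedding-based screen for unhelpful bot responses: each bot turn is embedded with sentence-transformers and dropped if it is too close to any phrase in a known-unhelpful list. A minimal sketch of that filtering idea is below; the threshold value and the example phrases are assumptions for illustration, not the repo's actual list.

# Sketch of the cosine-similarity filter used in test_check_unhelpful above.
import numpy as np
from sentence_transformers import SentenceTransformer  # pip install sentence-transformers
from sklearn.metrics.pairwise import cosine_similarity


def filter_unhelpful(bot_responses, unhelpful_phrases, cosine_sim_threshold=0.6):
    model = SentenceTransformer('all-MiniLM-L6-v2')
    unhelpful_emb = model.encode(unhelpful_phrases)  # (n_phrases, dim)
    kept = []
    for response in bot_responses:
        sim = cosine_similarity(model.encode([response]), unhelpful_emb)
        if np.max(sim) < cosine_sim_threshold:  # keep only responses far from all unhelpful phrases
            kept.append(response)
    return kept


if __name__ == '__main__':
    unhelpful = ["I'm sorry, I cannot answer that.", "As an AI language model, I cannot"]
    bots = ["As an AI language model, I cannot help with that.",
            "Set max_runtime_minutes in config.toml."]
    print(filter_unhelpful(bots, unhelpful))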
finetune.py
CHANGED
@@ -65,7 +65,8 @@ def train(
         micro_batch_size: int = 4,
         gradient_checkpointing=False,  # unnecessary with gradient accumulation enabled
         fp16=True,
-        train_8bit=
+        train_8bit=False,
+        train_4bit=False,

         # general training hyperparams
         num_epochs: float = 1,
@@ -185,10 +186,12 @@ def train(
     model = model_loader.from_pretrained(
         base_model,
         load_in_8bit=train_8bit,
+        load_in_4bit=train_4bit,
         device_map=device_map,
         torch_dtype=torch.float16,
         max_memory=max_memory,
         local_files_only=local_files_only,
+        trust_remote_code=True,
         resume_download=resume_download,
         use_auth_token=use_auth_token,
     )
@@ -200,19 +203,12 @@ def train(

     tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)

-    if train_8bit:
+    if train_8bit or train_4bit:
         from peft import (
+            prepare_model_for_kbit_training,
         )

-        model = prepare_model_for_int8_training(model)
-    else:
-        model = prepare_model_for_int8_training(
-            model,
-            output_embedding_layer_name="embed_out",  # keep output logits in float32
-            layer_norm_names=["layer_norm", "layernorm"],  # keep all layer norms in higher precision
-        )
+        model = prepare_model_for_kbit_training(model)

     from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
     try:
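The finetune.py change switches from peft's int8-only preparation to prepare_model_for_kbit_training, so the same path covers 8-bit and the new 4-bit option. Below is a minimal sketch of that flow combined with LoRA; the model name, target_modules, and LoRA hyperparameters are illustrative assumptions, and the exact quantization kwargs depend on the transformers/peft/bitsandbytes versions installed.

# Sketch: quantized load, k-bit preparation, then LoRA adapters (assumptions noted above).
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


def load_kbit_lora(base_model: str, train_4bit: bool = True, train_8bit: bool = False):
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_4bit=train_4bit,   # requires bitsandbytes
        load_in_8bit=train_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    if train_4bit or train_8bit:
        # casts norms/embeddings appropriately and enables input grads for k-bit training
        model = prepare_model_for_kbit_training(model)
    lora_config = LoraConfig(
        r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
        target_modules=["q_proj", "v_proj"],  # assumption: module names depend on the architecture
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, lora_config)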
generate.py
CHANGED
@@ -9,24 +9,25 @@ import os
|
|
9 |
import time
|
10 |
import traceback
|
11 |
import typing
|
|
|
12 |
from datetime import datetime
|
13 |
import filelock
|
14 |
import psutil
|
15 |
|
|
|
|
|
|
|
|
|
16 |
from loaders import get_loaders
|
17 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
18 |
-
import_matplotlib, get_device, makedirs
|
19 |
|
20 |
import_matplotlib()
|
21 |
-
from matplotlib import pyplot as plt
|
22 |
|
23 |
SEED = 1236
|
24 |
set_seed(SEED)
|
25 |
|
26 |
-
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
27 |
from typing import Union
|
28 |
-
import numpy as np
|
29 |
-
import pandas as pd
|
30 |
|
31 |
import fire
|
32 |
import torch
|
@@ -34,7 +35,7 @@ from peft import PeftModel
|
|
34 |
from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
|
35 |
from accelerate import init_empty_weights, infer_auto_device_map
|
36 |
|
37 |
-
from prompter import Prompter, inv_prompt_type_to_model_lower
|
38 |
from stopping import get_stopping
|
39 |
|
40 |
eval_extra_columns = ['prompt', 'response', 'score']
|
@@ -47,12 +48,14 @@ scratch_base_dir = '/tmp/'
|
|
47 |
|
48 |
def main(
|
49 |
load_8bit: bool = False,
|
|
|
50 |
load_half: bool = True,
|
51 |
infer_devices: bool = True,
|
52 |
base_model: str = '',
|
53 |
tokenizer_base_model: str = '',
|
54 |
lora_weights: str = "",
|
55 |
gpu_id: int = 0,
|
|
|
56 |
|
57 |
prompt_type: Union[int, str] = None,
|
58 |
# input to generation
|
@@ -68,6 +71,7 @@ def main(
|
|
68 |
early_stopping: Union[bool, str] = None,
|
69 |
max_time: float = None,
|
70 |
|
|
|
71 |
debug: bool = False,
|
72 |
save_dir: str = None,
|
73 |
share: bool = True,
|
@@ -80,15 +84,18 @@ def main(
|
|
80 |
src_lang: str = "English",
|
81 |
tgt_lang: str = "Russian",
|
82 |
|
|
|
|
|
83 |
gradio: bool = True,
|
84 |
gradio_avoid_processing_markdown: bool = False,
|
|
|
85 |
chat: bool = True,
|
86 |
chat_context: bool = False,
|
87 |
stream_output: bool = True,
|
88 |
show_examples: bool = None,
|
89 |
verbose: bool = False,
|
90 |
-
h2ocolors: bool =
|
91 |
-
height: int =
|
92 |
show_lora: bool = True,
|
93 |
login_mode_if_model0: bool = False,
|
94 |
block_gradio_exit: bool = True,
|
@@ -107,13 +114,16 @@ def main(
|
|
107 |
score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
|
108 |
auto_score: bool = True,
|
109 |
|
110 |
-
|
111 |
-
|
112 |
-
|
|
|
113 |
|
114 |
langchain_mode: str = 'Disabled',
|
115 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
|
|
116 |
user_path: str = None,
|
|
|
117 |
load_db_if_exists: bool = True,
|
118 |
keep_sources_in_context: bool = False,
|
119 |
db_type: str = 'chroma',
|
@@ -127,7 +137,7 @@ def main(
|
|
127 |
enable_sources_list: bool = True,
|
128 |
chunk: bool = True,
|
129 |
chunk_size: int = 512,
|
130 |
-
|
131 |
n_jobs: int = -1,
|
132 |
enable_captions: bool = True,
|
133 |
captions_model: str = "Salesforce/blip-image-captioning-base",
|
@@ -138,12 +148,14 @@ def main(
|
|
138 |
"""
|
139 |
|
140 |
:param load_8bit: load model in 8-bit using bitsandbytes
|
|
|
141 |
:param load_half: load model in float16
|
142 |
:param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
|
143 |
-
:param base_model: model HF-type name
|
144 |
-
:param tokenizer_base_model: tokenizer HF-type name
|
145 |
:param lora_weights: LORA weights path/HF link
|
146 |
:param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
|
|
|
147 |
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
|
148 |
:param temperature: generation temperature
|
149 |
:param top_p: generation top_p
|
@@ -156,6 +168,7 @@ def main(
|
|
156 |
:param min_new_tokens: generation min tokens
|
157 |
:param early_stopping: generation early stopping
|
158 |
:param max_time: maximum time to allow for generation
|
|
|
159 |
:param debug: enable debug mode
|
160 |
:param save_dir: directory chat data is saved to
|
161 |
:param share: whether to share the gradio app with sharable URL
|
@@ -166,8 +179,16 @@ def main(
|
|
166 |
:param offload_folder: path for spilling model onto disk
|
167 |
:param src_lang: source languages to include if doing translation (None = all)
|
168 |
:param tgt_lang: target languages to include if doing translation (None = all)
|
|
|
|
|
169 |
:param gradio: whether to enable gradio, or to enable benchmark mode
|
170 |
:param gradio_avoid_processing_markdown:
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
:param chat: whether to enable chat mode with chat history
|
172 |
:param chat_context: whether to use extra helpful context if human_bot
|
173 |
:param stream_output: whether to stream output from generate
|
@@ -190,32 +211,37 @@ def main(
|
|
190 |
:param extra_lora_options: extra LORA to show in list in gradio
|
191 |
:param score_model: which model to score responses (None means no scoring)
|
192 |
:param auto_score: whether to automatically score responses
|
193 |
-
:param
|
194 |
-
:param
|
195 |
-
:param
|
|
|
196 |
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
197 |
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
198 |
-
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode
|
|
|
|
|
|
|
199 |
:param visible_langchain_modes: dbs to generate at launch to be ready for LLM
|
200 |
Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
|
201 |
But wiki_full is expensive and requires preparation
|
202 |
To allow scratch space only live in session, add 'MyData' to list
|
203 |
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
204 |
FIXME: Avoid 'All' for now, not implemented
|
|
|
205 |
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
206 |
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
207 |
-
:param db_type: 'faiss' for in-memory or 'chroma' for persisted on disk
|
208 |
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db
|
209 |
:param use_openai_model: Whether to use OpenAI model for use with vector db
|
210 |
:param hf_embedding_model: Which HF embedding model to use for vector db
|
211 |
:param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
|
212 |
:param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
|
213 |
:param enable_url_upload: Whether to allow upload from URL
|
214 |
-
:param enable_text_upload: Whether to allow
|
215 |
:param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
|
216 |
:param chunk: Whether to chunk data (True unless know data is already optimally chunked)
|
217 |
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
|
218 |
-
:param
|
219 |
:param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
|
220 |
:param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
|
221 |
:param captions_model: Which model to use for captions.
|
@@ -233,7 +259,10 @@ def main(
|
|
233 |
is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
|
234 |
is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
|
235 |
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
236 |
-
|
|
|
|
|
|
|
237 |
admin_pass = os.getenv("ADMIN_PASS")
|
238 |
# will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
|
239 |
# but becomes unrecoverable sometimes if raise, so just be silent for now
|
@@ -265,21 +294,23 @@ def main(
|
|
265 |
# by default don't sample, too chatty
|
266 |
do_sample = False if do_sample is None else do_sample
|
267 |
|
268 |
-
if
|
269 |
if not base_model:
|
270 |
base_model = 'h2oai/h2ogpt-oasst1-512-12b'
|
271 |
# don't set load_8bit if passed base_model, doesn't always work so can't just override
|
272 |
load_8bit = True
|
|
|
273 |
else:
|
274 |
base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
|
275 |
-
if
|
276 |
load_8bit = True
|
|
|
277 |
if is_hf:
|
278 |
# must override share if in spaces
|
279 |
share = False
|
280 |
save_dir = os.getenv('SAVE_DIR', save_dir)
|
281 |
score_model = os.getenv('SCORE_MODEL', score_model)
|
282 |
-
if score_model == 'None':
|
283 |
score_model = ''
|
284 |
concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
|
285 |
api_open = bool(int(os.getenv('API_OPEN', api_open)))
|
@@ -289,6 +320,7 @@ def main(
|
|
289 |
if n_gpus == 0:
|
290 |
gpu_id = None
|
291 |
load_8bit = False
|
|
|
292 |
load_half = False
|
293 |
infer_devices = False
|
294 |
torch.backends.cudnn.benchmark = True
|
@@ -328,12 +360,15 @@ def main(
|
|
328 |
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
329 |
repetition_penalty, num_return_sequences,
|
330 |
do_sample,
|
|
|
|
|
331 |
)
|
332 |
|
333 |
locals_dict = locals()
|
334 |
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
335 |
-
|
336 |
-
|
|
|
337 |
|
338 |
if langchain_mode != "Disabled":
|
339 |
# SECOND PLACE where LangChain referenced, but all imports are kept local so not required
|
@@ -353,7 +388,9 @@ def main(
|
|
353 |
# FIXME: All should be avoided until scans over each db, shouldn't be separate db
|
354 |
continue
|
355 |
persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case
|
356 |
-
db = prep_langchain(persist_directory1,
|
|
|
|
|
357 |
langchain_mode1, user_path,
|
358 |
hf_embedding_model,
|
359 |
kwargs_make_db=locals())
|
@@ -367,174 +404,30 @@ def main(
|
|
367 |
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
368 |
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
369 |
|
370 |
-
if
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
import json
|
378 |
-
data = json.load(open(eval_filename, 'rt'))
|
379 |
-
# focus on data that starts with human, else likely chopped from other data
|
380 |
-
turn_start = 0 # odd in general
|
381 |
-
data = [x for x in data if len(x['conversations']) > turn_start + 1 and
|
382 |
-
x['conversations'][turn_start]['from'] == 'human' and
|
383 |
-
x['conversations'][turn_start + 1]['from'] == 'gpt']
|
384 |
-
np.random.seed(eval_sharegpt_prompts_only_seed)
|
385 |
-
example1 = examples[-1] # pick reference example
|
386 |
-
examples = []
|
387 |
-
responses = []
|
388 |
-
for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
|
389 |
-
assert data[i]['conversations'][turn_start]['from'] == 'human'
|
390 |
-
instruction = data[i]['conversations'][turn_start]['value']
|
391 |
-
assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
|
392 |
-
output = data[i]['conversations'][turn_start + 1]['value']
|
393 |
-
examplenew = example1.copy()
|
394 |
-
assert not chat, "No gradio must use chat=False, uses nochat instruct"
|
395 |
-
examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
|
396 |
-
examplenew[eval_func_param_names.index('iinput_nochat')] = '' # no input
|
397 |
-
examplenew[eval_func_param_names.index('context')] = get_context(chat_context, prompt_type)
|
398 |
-
examples.append(examplenew)
|
399 |
-
responses.append(output)
|
400 |
-
|
401 |
-
num_examples = len(examples)
|
402 |
-
scoring_path = 'scoring'
|
403 |
-
os.makedirs(scoring_path, exist_ok=True)
|
404 |
-
if eval_sharegpt_as_output:
|
405 |
-
used_base_model = 'gpt35'
|
406 |
-
used_lora_weights = ''
|
407 |
-
else:
|
408 |
-
used_base_model = str(base_model.split('/')[-1])
|
409 |
-
used_lora_weights = str(lora_weights.split('/')[-1])
|
410 |
-
eval_filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
|
411 |
-
eval_sharegpt_prompts_only_seed,
|
412 |
-
eval_sharegpt_as_output,
|
413 |
-
used_base_model,
|
414 |
-
used_lora_weights)
|
415 |
-
eval_filename = os.path.join(scoring_path, eval_filename)
|
416 |
-
|
417 |
-
# torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
|
418 |
-
device = 'cpu' if n_gpus == 0 else 'cuda'
|
419 |
-
context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device
|
420 |
-
|
421 |
-
with context_class(device):
|
422 |
-
# ensure was set right above before examples generated
|
423 |
-
assert not stream_output, "stream_output=True does not make sense with example loop"
|
424 |
-
import time
|
425 |
-
from functools import partial
|
426 |
-
|
427 |
-
# get score model
|
428 |
-
smodel, stokenizer, sdevice = get_score_model(**locals())
|
429 |
-
|
430 |
-
if not eval_sharegpt_as_output:
|
431 |
-
model, tokenizer, device = get_model(**locals())
|
432 |
-
model_state = [model, tokenizer, device, base_model]
|
433 |
-
kwargs_evaluate = {k: v for k, v in locals().items() if k in inputs_kwargs_list}
|
434 |
-
my_db_state = [None]
|
435 |
-
fun = partial(evaluate, model_state, my_db_state, **kwargs_evaluate)
|
436 |
-
else:
|
437 |
-
assert eval_sharegpt_prompts_only > 0
|
438 |
-
|
439 |
-
def get_response(*args, exi=0):
|
440 |
-
# assumes same ordering of examples and responses
|
441 |
-
yield responses[exi]
|
442 |
-
|
443 |
-
fun = get_response
|
444 |
-
t0 = time.time()
|
445 |
-
score_dump = []
|
446 |
-
|
447 |
-
for exi, ex in enumerate(examples):
|
448 |
-
instruction = ex[eval_func_param_names.index('instruction_nochat')]
|
449 |
-
iinput = ex[eval_func_param_names.index('iinput_nochat')]
|
450 |
-
context = ex[eval_func_param_names.index('context')]
|
451 |
-
clear_torch_cache()
|
452 |
-
print("")
|
453 |
-
print("START" + "=" * 100)
|
454 |
-
print("Question: %s %s" % (instruction, ('input=%s' % iinput if iinput else '')))
|
455 |
-
print("-" * 105)
|
456 |
-
# fun yields as generator, so have to iterate over it
|
457 |
-
# Also means likely do NOT want --stream_output=True, else would show all generations
|
458 |
-
gener = fun(*tuple(ex), exi=exi) if eval_sharegpt_as_output else fun(*tuple(ex))
|
459 |
-
for res in gener:
|
460 |
-
print(res)
|
461 |
-
if smodel:
|
462 |
-
score_with_prompt = False
|
463 |
-
if score_with_prompt:
|
464 |
-
data_point = dict(instruction=instruction, input=iinput, context=context)
|
465 |
-
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
466 |
-
prompt = prompter.generate_prompt(data_point)
|
467 |
-
else:
|
468 |
-
# just raw input and output
|
469 |
-
if eval_sharegpt_prompts_only > 0:
|
470 |
-
# only our own examples have this filled at moment
|
471 |
-
assert iinput in [None, ''], iinput # should be no iinput
|
472 |
-
if not (chat_context and prompt_type == 'human_bot'):
|
473 |
-
assert context in [None, ''], context # should be no context
|
474 |
-
prompt = instruction
|
475 |
-
cutoff_len = 768 if is_low_mem else 2048
|
476 |
-
inputs = stokenizer(prompt, res,
|
477 |
-
return_tensors="pt",
|
478 |
-
truncation=True,
|
479 |
-
max_length=cutoff_len)
|
480 |
-
try:
|
481 |
-
score = torch.sigmoid(smodel(**inputs).logits[0].float()).cpu().detach().numpy()[0]
|
482 |
-
except torch.cuda.OutOfMemoryError as e:
|
483 |
-
print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
|
484 |
-
flush=True)
|
485 |
-
traceback.print_exc()
|
486 |
-
score = 0.0
|
487 |
-
clear_torch_cache()
|
488 |
-
except (Exception, RuntimeError) as e:
|
489 |
-
if 'Expected all tensors to be on the same device' in str(e) or \
|
490 |
-
'expected scalar type Half but found Float' in str(e) or \
|
491 |
-
'probability tensor contains either' in str(e) or \
|
492 |
-
'cublasLt ran into an error!' in str(e):
|
493 |
-
print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
|
494 |
-
flush=True)
|
495 |
-
traceback.print_exc()
|
496 |
-
score = 0.0
|
497 |
-
clear_torch_cache()
|
498 |
-
else:
|
499 |
-
raise
|
500 |
-
print("SCORE %s: %s" % (exi, score), flush=True)
|
501 |
-
score_dump.append(ex + [prompt, res, score])
|
502 |
-
# dump every score in case abort
|
503 |
-
df_scores = pd.DataFrame(score_dump,
|
504 |
-
columns=eval_func_param_names + eval_extra_columns)
|
505 |
-
df_scores.to_parquet(eval_filename, index=False)
|
506 |
-
# plot histogram so far
|
507 |
-
plt.figure(figsize=(10, 10))
|
508 |
-
plt.hist(df_scores['score'], bins=20)
|
509 |
-
score_avg = np.mean(df_scores['score'])
|
510 |
-
score_median = np.median(df_scores['score'])
|
511 |
-
plt.title("Score avg: %s median: %s" % (score_avg, score_median))
|
512 |
-
plt.savefig(eval_filename.replace('.parquet', '.png'))
|
513 |
-
plt.close()
|
514 |
-
|
515 |
-
print("END" + "=" * 102)
|
516 |
-
print("")
|
517 |
-
t2 = time.time()
|
518 |
-
print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
|
519 |
-
t1 = time.time()
|
520 |
-
print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
|
521 |
-
return eval_filename
|
522 |
-
|
523 |
-
if gradio:
|
524 |
# imported here so don't require gradio to run generate
|
525 |
from gradio_runner import go_gradio
|
526 |
|
527 |
# get default model
|
528 |
all_kwargs = locals().copy()
|
529 |
if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
|
530 |
-
model0, tokenizer0, device = get_model(
|
|
|
531 |
else:
|
532 |
# if empty model, then don't load anything, just get gradio up
|
533 |
model0, tokenizer0, device = None, None, None
|
534 |
model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
|
535 |
|
536 |
# get score model
|
537 |
-
smodel, stokenizer, sdevice = get_score_model(
|
|
|
|
|
538 |
score_model_state0 = [smodel, stokenizer, sdevice, score_model]
|
539 |
|
540 |
if enable_captions:
|
@@ -546,6 +439,7 @@ def main(
|
|
546 |
else:
|
547 |
caption_loader = False
|
548 |
|
|
|
549 |
go_gradio(**locals())
|
550 |
|
551 |
|
@@ -624,12 +518,15 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
624 |
else:
|
625 |
device_map = {'': 'cpu'}
|
626 |
model_kwargs['load_in_8bit'] = False
|
|
|
627 |
print('device_map: %s' % device_map, flush=True)
|
628 |
|
629 |
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
|
|
630 |
model_kwargs['device_map'] = device_map
|
|
|
631 |
|
632 |
-
if load_in_8bit or not load_half:
|
633 |
model = model_loader.from_pretrained(
|
634 |
base_model,
|
635 |
config=config,
|
@@ -646,6 +543,7 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
646 |
|
647 |
def get_model(
|
648 |
load_8bit: bool = False,
|
|
|
649 |
load_half: bool = True,
|
650 |
infer_devices: bool = True,
|
651 |
base_model: str = '',
|
@@ -659,12 +557,14 @@ def get_model(
|
|
659 |
use_auth_token: Union[str, bool] = False,
|
660 |
trust_remote_code: bool = True,
|
661 |
offload_folder: str = None,
|
662 |
-
|
663 |
-
|
|
|
664 |
):
|
665 |
"""
|
666 |
|
667 |
:param load_8bit: load model in 8-bit, not supported by all models
|
|
|
668 |
:param load_half: load model in 16-bit
|
669 |
:param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
|
670 |
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
@@ -679,26 +579,29 @@ def get_model(
|
|
679 |
:param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
|
680 |
:param trust_remote_code: trust code needed by model
|
681 |
:param offload_folder: offload folder
|
682 |
-
:param
|
683 |
-
:param
|
684 |
:return:
|
685 |
"""
|
686 |
-
|
687 |
-
|
|
|
688 |
from gpt4all_llm import get_model_tokenizer_gpt4all
|
689 |
model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
|
690 |
return model, tokenizer, device
|
691 |
|
692 |
if lora_weights is not None and lora_weights.strip():
|
693 |
-
|
|
|
694 |
device = get_device()
|
695 |
|
696 |
if 'gpt2' in base_model.lower():
|
697 |
# RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
|
698 |
load_8bit = False
|
|
|
699 |
|
700 |
assert base_model.strip(), (
|
701 |
-
"Please choose a base model with --base_model (CLI) or
|
702 |
)
|
703 |
|
704 |
from transformers import AutoConfig
|
@@ -709,8 +612,9 @@ def get_model(
|
|
709 |
llama_type_from_name = "llama" in base_model.lower()
|
710 |
llama_type = llama_type_from_config or llama_type_from_name
|
711 |
if llama_type:
|
712 |
-
|
713 |
-
|
|
|
714 |
|
715 |
model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
|
716 |
if not tokenizer_base_model:
|
@@ -744,7 +648,8 @@ def get_model(
|
|
744 |
)
|
745 |
if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
|
746 |
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
747 |
-
|
|
|
748 |
))
|
749 |
if 'mpt-' in base_model.lower() and gpu_id >= 0:
|
750 |
model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
|
@@ -753,6 +658,7 @@ def get_model(
|
|
753 |
# FIXME: could put on other GPUs
|
754 |
model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
|
755 |
model_kwargs.pop('torch_dtype', None)
|
|
|
756 |
|
757 |
if not lora_weights:
|
758 |
with torch.device(device):
|
@@ -764,7 +670,7 @@ def get_model(
|
|
764 |
offload_folder=offload_folder,
|
765 |
)
|
766 |
else:
|
767 |
-
if load_half and not load_8bit:
|
768 |
model = model_loader.from_pretrained(
|
769 |
base_model,
|
770 |
**model_kwargs).half()
|
@@ -772,7 +678,7 @@ def get_model(
|
|
772 |
model = model_loader.from_pretrained(
|
773 |
base_model,
|
774 |
**model_kwargs)
|
775 |
-
elif load_8bit:
|
776 |
model = model_loader.from_pretrained(
|
777 |
base_model,
|
778 |
**model_kwargs
|
@@ -821,24 +727,62 @@ def get_model(
|
|
821 |
|
822 |
if not isinstance(tokenizer, str):
|
823 |
model.eval()
|
824 |
-
if torch.__version__ >= "2" and sys.platform != "win32" and
|
825 |
model = torch.compile(model)
|
|
827 |
return model, tokenizer, device
|
828 |
|
829 |
|
830 |
-
def
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
-
|
|
|
|
|
|
|
842 |
else:
|
843 |
smodel, stokenizer, sdevice = None, None, None
|
844 |
return smodel, stokenizer, sdevice
|
@@ -864,6 +808,7 @@ eval_func_param_names = ['instruction',
|
|
864 |
'instruction_nochat',
|
865 |
'iinput_nochat',
|
866 |
'langchain_mode',
|
|
|
867 |
'document_choice',
|
868 |
]
|
869 |
|
@@ -892,6 +837,7 @@ def evaluate(
|
|
892 |
instruction_nochat,
|
893 |
iinput_nochat,
|
894 |
langchain_mode,
|
|
|
895 |
document_choice,
|
896 |
# END NOTE: Examples must have same order of parameters
|
897 |
src_lang=None,
|
@@ -901,27 +847,29 @@ def evaluate(
|
|
901 |
save_dir=None,
|
902 |
sanitize_bot_response=True,
|
903 |
model_state0=None,
|
904 |
-
|
905 |
raise_generate_gpu_exceptions=None,
|
906 |
chat_context=None,
|
907 |
lora_weights=None,
|
908 |
load_db_if_exists=True,
|
909 |
dbs=None,
|
910 |
user_path=None,
|
|
|
911 |
use_openai_embedding=None,
|
912 |
use_openai_model=None,
|
913 |
hf_embedding_model=None,
|
914 |
chunk=None,
|
915 |
chunk_size=None,
|
916 |
db_type=None,
|
917 |
-
k=None,
|
918 |
n_jobs=None,
|
919 |
first_para=None,
|
920 |
text_limit=None,
|
|
|
|
|
921 |
):
|
922 |
# ensure passed these
|
923 |
assert concurrency_count is not None
|
924 |
-
assert
|
925 |
assert raise_generate_gpu_exceptions is not None
|
926 |
assert chat_context is not None
|
927 |
assert use_openai_embedding is not None
|
@@ -930,7 +878,7 @@ def evaluate(
|
|
930 |
assert chunk is not None
|
931 |
assert chunk_size is not None
|
932 |
assert db_type is not None
|
933 |
-
assert
|
934 |
assert n_jobs is not None
|
935 |
assert first_para is not None
|
936 |
|
@@ -940,7 +888,7 @@ def evaluate(
|
|
940 |
locals_dict.pop('model_state0', None)
|
941 |
print(locals_dict)
|
942 |
|
943 |
-
no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
|
944 |
|
945 |
if model_state0 is None:
|
946 |
# e.g. for no gradio case, set dummy value, else should be set
|
@@ -990,7 +938,7 @@ def evaluate(
|
|
990 |
db1 = dbs[langchain_mode]
|
991 |
else:
|
992 |
db1 = None
|
993 |
-
if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in
|
994 |
query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
|
995 |
outr = ""
|
996 |
# use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
|
@@ -1002,6 +950,7 @@ def evaluate(
|
|
1002 |
load_db_if_exists=load_db_if_exists,
|
1003 |
db=db1,
|
1004 |
user_path=user_path,
|
|
|
1005 |
max_new_tokens=max_new_tokens,
|
1006 |
cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
|
1007 |
use_openai_embedding=use_openai_embedding,
|
@@ -1014,21 +963,28 @@ def evaluate(
|
|
1014 |
langchain_mode=langchain_mode,
|
1015 |
document_choice=document_choice,
|
1016 |
db_type=db_type,
|
1017 |
-
k=
|
1018 |
temperature=temperature,
|
1019 |
repetition_penalty=repetition_penalty,
|
1020 |
top_k=top_k,
|
1021 |
top_p=top_p,
|
1022 |
prompt_type=prompt_type,
|
1023 |
n_jobs=n_jobs,
|
|
|
|
|
1024 |
):
|
1025 |
-
outr = r # doesn't accumulate, new answer every yield, so only save that full answer
|
1026 |
-
yield
|
1027 |
if save_dir:
|
1028 |
save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
|
1029 |
-
|
1030 |
-
|
1031 |
-
|
|
|
|
|
|
|
|
|
|
|
1032 |
return
|
1033 |
|
1034 |
if isinstance(tokenizer, str):
|
@@ -1038,7 +994,7 @@ def evaluate(
|
|
1038 |
else:
|
1039 |
raise RuntimeError("No such task type %s" % tokenizer)
|
1040 |
# NOTE: uses max_length only
|
1041 |
-
yield model(prompt, max_length=max_new_tokens)[0][key]
|
1042 |
|
1043 |
if 'mbart-' in base_model.lower():
|
1044 |
assert src_lang is not None
|
@@ -1048,7 +1004,7 @@ def evaluate(
|
|
1048 |
# override, ignore user change
|
1049 |
num_return_sequences = 1
|
1050 |
stopping_criteria = get_stopping(prompt_type, tokenizer, device)
|
1051 |
-
_, _, max_length_tokenize, max_prompt_length = get_cutoffs(
|
1052 |
prompt = prompt[-max_prompt_length:]
|
1053 |
inputs = tokenizer(prompt,
|
1054 |
return_tensors="pt",
|
@@ -1059,6 +1015,10 @@ def evaluate(
|
|
1059 |
if debug and len(inputs["input_ids"]) > 0:
|
1060 |
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
1061 |
input_ids = inputs["input_ids"].to(device)
|
|
|
|
|
|
|
|
|
1062 |
generation_config = GenerationConfig(
|
1063 |
temperature=float(temperature),
|
1064 |
top_p=float(top_p),
|
@@ -1111,10 +1071,12 @@ def evaluate(
|
|
1111 |
# https://github.com/h2oai/h2ogpt/issues/104
|
1112 |
# but only makes sense if concurrency_count == 1
|
1113 |
context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
|
1114 |
-
|
|
|
1115 |
decoded_output = None
|
1116 |
with context_class("generate.lock"):
|
1117 |
-
|
|
|
1118 |
# decoded tokenized prompt can deviate from prompt due to special characters
|
1119 |
inputs_decoded = decoder(input_ids[0])
|
1120 |
inputs_decoded_raw = decoder_raw(input_ids[0])
|
@@ -1136,7 +1098,8 @@ def evaluate(
|
|
1136 |
decoder = decoder_raw
|
1137 |
decoder_kwargs = decoder_raw_kwargs
|
1138 |
else:
|
1139 |
-
|
|
|
1140 |
if stream_output:
|
1141 |
skip_prompt = False
|
1142 |
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
|
@@ -1155,8 +1118,9 @@ def evaluate(
|
|
1155 |
if bucket.qsize() > 0 or thread.exc:
|
1156 |
thread.join()
|
1157 |
outputs += new_text
|
1158 |
-
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
1159 |
-
|
|
|
1160 |
except BaseException:
|
1161 |
# if any exception, raise that exception if was from thread, first
|
1162 |
if thread.exc:
|
@@ -1173,14 +1137,15 @@ def evaluate(
|
|
1173 |
else:
|
1174 |
outputs = model.generate(**gen_kwargs)
|
1175 |
outputs = [decoder(s) for s in outputs.sequences]
|
1176 |
-
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
1177 |
-
|
1178 |
if outputs and len(outputs) >= 1:
|
1179 |
decoded_output = prompt + outputs[0]
|
1180 |
if save_dir and decoded_output:
|
1181 |
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
1182 |
-
|
1183 |
-
|
|
|
1184 |
|
1185 |
|
1186 |
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
@@ -1188,12 +1153,15 @@ state_names = ['model_state', 'my_db_state']
|
|
1188 |
inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
|
1189 |
|
1190 |
|
1191 |
-
def get_cutoffs(
|
1192 |
# help to avoid errors like:
|
1193 |
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
|
1194 |
# RuntimeError: expected scalar type Half but found Float
|
1195 |
# with - 256
|
1196 |
-
|
|
|
|
|
|
|
1197 |
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
1198 |
output_smallest = 30 * 4
|
1199 |
max_prompt_length = cutoff_len - output_smallest
|
@@ -1286,7 +1254,7 @@ def get_generate_params(model_lower, chat,
|
|
1286 |
prompt_type, temperature, top_p, top_k, num_beams,
|
1287 |
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
1288 |
repetition_penalty, num_return_sequences,
|
1289 |
-
do_sample):
|
1290 |
use_defaults = False
|
1291 |
use_default_examples = True
|
1292 |
examples = []
|
@@ -1303,7 +1271,8 @@ def get_generate_params(model_lower, chat,
|
|
1303 |
|
1304 |
if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
|
1305 |
prompt_type = inv_prompt_type_to_model_lower[model_lower]
|
1306 |
-
|
|
|
1307 |
|
1308 |
# examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
|
1309 |
if show_examples is None:
|
@@ -1366,9 +1335,6 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
|
|
1366 |
prompt_type = prompt_type or 'plain'
|
1367 |
else:
|
1368 |
prompt_type = ''
|
1369 |
-
examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
|
1370 |
-
stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
|
1371 |
-
False]]
|
1372 |
task_info = "No task"
|
1373 |
if prompt_type == 'instruct':
|
1374 |
task_info = "Answer question or follow imperative as instruction with optionally input."
|
@@ -1443,13 +1409,15 @@ y = np.random.randint(0, 1, 100)
|
|
1443 |
|
1444 |
# fit random forest classifier with 20 estimators""", ''] + params_list,
|
1445 |
]
|
|
|
|
|
1446 |
|
1447 |
src_lang = "English"
|
1448 |
tgt_lang = "Russian"
|
1449 |
|
1450 |
# move to correct position
|
1451 |
for example in examples:
|
1452 |
-
example += [chat, '', '', 'Disabled', ['All']]
|
1453 |
# adjust examples if non-chat mode
|
1454 |
if not chat:
|
1455 |
example[eval_func_param_names.index('instruction_nochat')] = example[
|
@@ -1521,6 +1489,32 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
|
|
1521 |
return score
|
1522 |
|
1523 |
|
|
|
|
|
|
|
|
|
|
1524 |
if __name__ == "__main__":
|
1525 |
"""
|
1526 |
Examples:
|
|
|
9 |
import time
|
10 |
import traceback
|
11 |
import typing
|
12 |
+
import warnings
|
13 |
from datetime import datetime
|
14 |
import filelock
|
15 |
import psutil
|
16 |
|
17 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
18 |
+
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
19 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
20 |
+
|
21 |
from loaders import get_loaders
|
22 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
23 |
+
import_matplotlib, get_device, makedirs, get_kwargs
|
24 |
|
25 |
import_matplotlib()
|
|
|
26 |
|
27 |
SEED = 1236
|
28 |
set_seed(SEED)
|
29 |
|
|
|
30 |
from typing import Union
|
|
|
|
|
31 |
|
32 |
import fire
|
33 |
import torch
|
|
|
35 |
from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
|
36 |
from accelerate import init_empty_weights, infer_auto_device_map
|
37 |
|
38 |
+
from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types
|
39 |
from stopping import get_stopping
|
40 |
|
41 |
eval_extra_columns = ['prompt', 'response', 'score']
|
|
|
48 |
|
49 |
def main(
|
50 |
load_8bit: bool = False,
|
51 |
+
load_4bit: bool = False,
|
52 |
load_half: bool = True,
|
53 |
infer_devices: bool = True,
|
54 |
base_model: str = '',
|
55 |
tokenizer_base_model: str = '',
|
56 |
lora_weights: str = "",
|
57 |
gpu_id: int = 0,
|
58 |
+
compile_model: bool = True,
|
59 |
|
60 |
prompt_type: Union[int, str] = None,
|
61 |
# input to generation
|
|
|
71 |
early_stopping: Union[bool, str] = None,
|
72 |
max_time: float = None,
|
73 |
|
74 |
+
memory_restriction_level: int = None,
|
75 |
debug: bool = False,
|
76 |
save_dir: str = None,
|
77 |
share: bool = True,
|
|
|
84 |
src_lang: str = "English",
|
85 |
tgt_lang: str = "Russian",
|
86 |
|
87 |
+
cli: bool = False,
|
88 |
+
cli_loop: bool = True,
|
89 |
gradio: bool = True,
|
90 |
gradio_avoid_processing_markdown: bool = False,
|
91 |
+
gradio_offline_level: int = 0,
|
92 |
chat: bool = True,
|
93 |
chat_context: bool = False,
|
94 |
stream_output: bool = True,
|
95 |
show_examples: bool = None,
|
96 |
verbose: bool = False,
|
97 |
+
h2ocolors: bool = False,
|
98 |
+
height: int = 600,
|
99 |
show_lora: bool = True,
|
100 |
login_mode_if_model0: bool = False,
|
101 |
block_gradio_exit: bool = True,
|
|
|
114 |
score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
|
115 |
auto_score: bool = True,
|
116 |
|
117 |
+
eval_filename: str = None,
|
118 |
+
eval_prompts_only_num: int = 0,
|
119 |
+
eval_prompts_only_seed: int = 1234,
|
120 |
+
eval_as_output: bool = False,
|
121 |
|
122 |
langchain_mode: str = 'Disabled',
|
123 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
124 |
+
document_choice: list = ['All'],
|
125 |
user_path: str = None,
|
126 |
+
detect_user_path_changes_every_query: bool = False,
|
127 |
load_db_if_exists: bool = True,
|
128 |
keep_sources_in_context: bool = False,
|
129 |
db_type: str = 'chroma',
|
|
|
137 |
enable_sources_list: bool = True,
|
138 |
chunk: bool = True,
|
139 |
chunk_size: int = 512,
|
140 |
+
top_k_docs: int = 3, # FIXME: Can go back to 4 once https://github.com/h2oai/h2ogpt/issues/192 fixed
|
141 |
n_jobs: int = -1,
|
142 |
enable_captions: bool = True,
|
143 |
captions_model: str = "Salesforce/blip-image-captioning-base",
|
|
|
148 |
"""
|
149 |
|
150 |
:param load_8bit: load model in 8-bit using bitsandbytes
|
151 |
+
:param load_4bit: load model in 4-bit using bitsandbytes
|
152 |
:param load_half: load model in float16
|
153 |
:param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
|
154 |
+
:param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab
|
155 |
+
:param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
|
156 |
:param lora_weights: LORA weights path/HF link
|
157 |
:param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
|
158 |
+
:param compile_model: Whether to compile the model
|
159 |
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
|
160 |
:param temperature: generation temperature
|
161 |
:param top_p: generation top_p
|
|
|
168 |
:param min_new_tokens: generation min tokens
|
169 |
:param early_stopping: generation early stopping
|
170 |
:param max_time: maximum time to allow for generation
|
171 |
+
:param memory_restriction_level: 0 = no restriction to tokens or model, 1 = some restrictions on tokens, 2 = HF-like restriction, 3 = very low memory case
|
172 |
:param debug: enable debug mode
|
173 |
:param save_dir: directory chat data is saved to
|
174 |
:param share: whether to share the gradio app with sharable URL
|
|
|
179 |
:param offload_folder: path for spilling model onto disk
|
180 |
:param src_lang: source languages to include if doing translation (None = all)
|
181 |
:param tgt_lang: target languages to include if doing translation (None = all)
|
182 |
+
:param cli: whether to use CLI (non-gradio) interface.
|
183 |
+
:param cli_loop: whether to loop for CLI (False usually only for testing)
|
184 |
:param gradio: whether to enable gradio, or to enable benchmark mode
|
185 |
:param gradio_avoid_processing_markdown:
|
186 |
+
:param gradio_offline_level: if > 0, then change fonts so the UI is fully offline
|
187 |
+
== 1 means backend won't need internet for fonts, but front-end UI might if font not cached
|
188 |
+
== 2 means backend and frontend don't need internet to download any fonts.
|
189 |
+
Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
|
190 |
+
This option further disables google fonts for downloading, which is less intrusive than uploading,
|
191 |
+
but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior.
|
192 |
:param chat: whether to enable chat mode with chat history
|
193 |
:param chat_context: whether to use extra helpful context if human_bot
|
194 |
:param stream_output: whether to stream output from generate
|
|
|
211 |
:param extra_lora_options: extra LORA to show in list in gradio
|
212 |
:param score_model: which model to score responses (None means no scoring)
|
213 |
:param auto_score: whether to automatically score responses
|
214 |
+
:param eval_filename: json file to use for evaluation, if None is sharegpt
|
215 |
+
:param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
|
216 |
+
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
|
217 |
+
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
|
218 |
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
219 |
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
220 |
+
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
|
221 |
+
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
|
222 |
+
:param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
|
223 |
+
Expensive for large number of files, so not done by default. By default only detect changes during db loading.
|
224 |
:param visible_langchain_modes: dbs to generate at launch to be ready for LLM
|
225 |
Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
|
226 |
But wiki_full is expensive and requires preparation
|
227 |
To allow scratch space only live in session, add 'MyData' to list
|
228 |
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
229 |
FIXME: Avoid 'All' for now, not implemented
|
230 |
+
:param document_choice: Default document choice when taking subset of collection
|
231 |
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
232 |
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
233 |
+
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
|
234 |
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db
|
235 |
:param use_openai_model: Whether to use OpenAI model for use with vector db
|
236 |
:param hf_embedding_model: Which HF embedding model to use for vector db
|
237 |
:param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
|
238 |
:param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
|
239 |
:param enable_url_upload: Whether to allow upload from URL
|
240 |
+
:param enable_text_upload: Whether to allow upload of text
|
241 |
:param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
|
242 |
:param chunk: Whether to chunk data (True unless know data is already optimally chunked)
|
243 |
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length
|
244 |
+
:param top_k_docs: number of chunks to give LLM
|
245 |
:param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
|
246 |
:param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
|
247 |
:param captions_model: Which model to use for captions.
|
|
|
259 |
is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
|
260 |
is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
|
261 |
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
262 |
+
if memory_restriction_level is None:
|
263 |
+
memory_restriction_level = 2 if is_hf else 0 # 2 assumes run on 24GB consumer GPU
|
264 |
+
else:
|
265 |
+
assert 0 <= memory_restriction_level <= 3, "Bad memory_restriction_level=%s" % memory_restriction_level
|
266 |
admin_pass = os.getenv("ADMIN_PASS")
|
267 |
# will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
|
268 |
# but becomes unrecoverable sometimes if raise, so just be silent for now
|
|
|
294 |
# by default don't sample, too chatty
|
295 |
do_sample = False if do_sample is None else do_sample
|
296 |
|
297 |
+
if memory_restriction_level == 2:
|
298 |
if not base_model:
|
299 |
base_model = 'h2oai/h2ogpt-oasst1-512-12b'
|
300 |
# don't set load_8bit if passed base_model, doesn't always work so can't just override
|
301 |
load_8bit = True
|
302 |
+
load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
|
303 |
else:
|
304 |
base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
|
305 |
+
if memory_restriction_level >= 2:
|
306 |
load_8bit = True
|
307 |
+
load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
|
308 |
if is_hf:
|
309 |
# must override share if in spaces
|
310 |
share = False
|
311 |
save_dir = os.getenv('SAVE_DIR', save_dir)
|
312 |
score_model = os.getenv('SCORE_MODEL', score_model)
|
313 |
+
if score_model == 'None' or score_model is None:
|
314 |
score_model = ''
|
315 |
concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
|
316 |
api_open = bool(int(os.getenv('API_OPEN', api_open)))
|
|
|
320 |
if n_gpus == 0:
|
321 |
gpu_id = None
|
322 |
load_8bit = False
|
323 |
+
load_4bit = False
|
324 |
load_half = False
|
325 |
infer_devices = False
|
326 |
torch.backends.cudnn.benchmark = True
|
|
|
360 |
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
361 |
repetition_penalty, num_return_sequences,
|
362 |
do_sample,
|
363 |
+
top_k_docs,
|
364 |
+
verbose,
|
365 |
)
|
366 |
|
367 |
locals_dict = locals()
|
368 |
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
369 |
+
if verbose:
|
370 |
+
print(f"Generating model with params:\n{locals_print}", flush=True)
|
371 |
+
print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
|
372 |
|
373 |
if langchain_mode != "Disabled":
|
374 |
# SECOND PLACE where LangChain referenced, but all imports are kept local so not required
|
|
|
388 |
# FIXME: All should be avoided until scans over each db, shouldn't be separate db
|
389 |
continue
|
390 |
persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case
|
391 |
+
db = prep_langchain(persist_directory1,
|
392 |
+
load_db_if_exists,
|
393 |
+
db_type, use_openai_embedding,
|
394 |
langchain_mode1, user_path,
|
395 |
hf_embedding_model,
|
396 |
kwargs_make_db=locals())
|
|
|
404 |
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
405 |
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
406 |
|
407 |
+
if cli:
|
408 |
+
from cli import run_cli
|
409 |
+
return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals()))
|
410 |
+
elif not gradio:
|
411 |
+
from eval import run_eval
|
412 |
+
return run_eval(**get_kwargs(run_eval, exclude_names=['model_state0'], **locals()))
|
413 |
+
elif gradio:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
# imported here so don't require gradio to run generate
|
415 |
from gradio_runner import go_gradio
|
416 |
|
417 |
# get default model
|
418 |
all_kwargs = locals().copy()
|
419 |
if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
|
420 |
+
model0, tokenizer0, device = get_model(reward_type=False,
|
421 |
+
**get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs))
|
422 |
else:
|
423 |
# if empty model, then don't load anything, just get gradio up
|
424 |
model0, tokenizer0, device = None, None, None
|
425 |
model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
|
426 |
|
427 |
# get score model
|
428 |
+
smodel, stokenizer, sdevice = get_score_model(reward_type=True,
|
429 |
+
**get_kwargs(get_score_model, exclude_names=['reward_type'],
|
430 |
+
**all_kwargs))
|
431 |
score_model_state0 = [smodel, stokenizer, sdevice, score_model]
|
432 |
|
433 |
if enable_captions:
|
|
|
439 |
else:
|
440 |
caption_loader = False
|
441 |
|
442 |
+
# assume gradio needs everything
|
443 |
go_gradio(**locals())
|
444 |
|
445 |
|
|
|
518 |
else:
|
519 |
device_map = {'': 'cpu'}
|
520 |
model_kwargs['load_in_8bit'] = False
|
521 |
+
model_kwargs['load_in_4bit'] = False
|
522 |
print('device_map: %s' % device_map, flush=True)
|
523 |
|
524 |
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
525 |
+
load_in_4bit = model_kwargs.get('load_in_4bit', False)
|
526 |
model_kwargs['device_map'] = device_map
|
527 |
+
pop_unused_model_kwargs(model_kwargs)
|
528 |
|
529 |
+
if load_in_8bit or load_in_4bit or not load_half:
|
530 |
model = model_loader.from_pretrained(
|
531 |
base_model,
|
532 |
config=config,
|
|
|
543 |
|
544 |
def get_model(
|
545 |
load_8bit: bool = False,
|
546 |
+
load_4bit: bool = False,
|
547 |
load_half: bool = True,
|
548 |
infer_devices: bool = True,
|
549 |
base_model: str = '',
|
|
|
557 |
use_auth_token: Union[str, bool] = False,
|
558 |
trust_remote_code: bool = True,
|
559 |
offload_folder: str = None,
|
560 |
+
compile_model: bool = True,
|
561 |
+
|
562 |
+
verbose: bool = False,
|
563 |
):
|
564 |
"""
|
565 |
|
566 |
:param load_8bit: load model in 8-bit, not supported by all models
|
567 |
+
:param load_4bit: load model in 4-bit, not supported by all models
|
568 |
:param load_half: load model in 16-bit
|
569 |
:param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
|
570 |
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
|
|
579 |
:param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
|
580 |
:param trust_remote_code: trust code needed by model
|
581 |
:param offload_folder: offload folder
|
582 |
+
:param compile_model: whether to compile torch model
|
583 |
+
:param verbose:
|
584 |
:return:
|
585 |
"""
|
586 |
+
if verbose:
|
587 |
+
print("Get %s model" % base_model, flush=True)
|
588 |
+
if base_model in non_hf_types:
|
589 |
from gpt4all_llm import get_model_tokenizer_gpt4all
|
590 |
model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
|
591 |
return model, tokenizer, device
|
592 |
|
593 |
if lora_weights is not None and lora_weights.strip():
|
594 |
+
if verbose:
|
595 |
+
print("Get %s lora weights" % lora_weights, flush=True)
|
596 |
device = get_device()
|
597 |
|
598 |
if 'gpt2' in base_model.lower():
|
599 |
# RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
|
600 |
load_8bit = False
|
601 |
+
load_4bit = False
|
602 |
|
603 |
assert base_model.strip(), (
|
604 |
+
"Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
|
605 |
)
|
606 |
|
607 |
from transformers import AutoConfig
|
|
|
612 |
llama_type_from_name = "llama" in base_model.lower()
|
613 |
llama_type = llama_type_from_config or llama_type_from_name
|
614 |
if llama_type:
|
615 |
+
if verbose:
|
616 |
+
print("Detected as llama type from"
|
617 |
+
" config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
|
618 |
|
619 |
model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
|
620 |
if not tokenizer_base_model:
|
|
|
648 |
)
|
649 |
if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
|
650 |
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
651 |
+
load_in_4bit=load_4bit,
|
652 |
+
device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
|
653 |
))
|
654 |
if 'mpt-' in base_model.lower() and gpu_id >= 0:
|
655 |
model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
|
|
|
658 |
# FIXME: could put on other GPUs
|
659 |
model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
|
660 |
model_kwargs.pop('torch_dtype', None)
|
661 |
+
pop_unused_model_kwargs(model_kwargs)
|
662 |
|
663 |
if not lora_weights:
|
664 |
with torch.device(device):
|
|
|
670 |
offload_folder=offload_folder,
|
671 |
)
|
672 |
else:
|
673 |
+
if load_half and not (load_8bit or load_4bit):
|
674 |
model = model_loader.from_pretrained(
|
675 |
base_model,
|
676 |
**model_kwargs).half()
|
|
|
678 |
model = model_loader.from_pretrained(
|
679 |
base_model,
|
680 |
**model_kwargs)
|
681 |
+
elif load_8bit or load_4bit:
|
682 |
model = model_loader.from_pretrained(
|
683 |
base_model,
|
684 |
**model_kwargs
|
|
|
727 |
|
728 |
if not isinstance(tokenizer, str):
|
729 |
model.eval()
|
730 |
+
if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
|
731 |
model = torch.compile(model)
|
732 |
|
733 |
+
if hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
|
734 |
+
# help automatically limit inputs to generate
|
735 |
+
tokenizer.model_max_length = config.max_position_embeddings
|
736 |
+
else:
|
737 |
+
tokenizer.model_max_length = 2048
|
738 |
+
|
739 |
return model, tokenizer, device
|
740 |
|
741 |
|
742 |
+
def pop_unused_model_kwargs(model_kwargs):
|
743 |
+
"""
|
744 |
+
in-place pop unused kwargs that are not dependency-upgrade friendly
|
745 |
+
no point passing in False since it is the default, and this helps avoid needing to update requirements for new deps
|
746 |
+
:param model_kwargs:
|
747 |
+
:return:
|
748 |
+
"""
|
749 |
+
check_list = ['load_in_8bit', 'load_in_4bit']
|
750 |
+
for k in check_list:
|
751 |
+
if k in model_kwargs and not model_kwargs[k]:
|
752 |
+
model_kwargs.pop(k)
|
753 |
+
|
754 |
+
|
755 |
+
def get_score_model(score_model: str = None,
|
756 |
+
load_8bit: bool = False,
|
757 |
+
load_4bit: bool = False,
|
758 |
+
load_half: bool = True,
|
759 |
+
infer_devices: bool = True,
|
760 |
+
base_model: str = '',
|
761 |
+
tokenizer_base_model: str = '',
|
762 |
+
lora_weights: str = "",
|
763 |
+
gpu_id: int = 0,
|
764 |
+
|
765 |
+
reward_type: bool = None,
|
766 |
+
local_files_only: bool = False,
|
767 |
+
resume_download: bool = True,
|
768 |
+
use_auth_token: Union[str, bool] = False,
|
769 |
+
trust_remote_code: bool = True,
|
770 |
+
offload_folder: str = None,
|
771 |
+
compile_model: bool = True,
|
772 |
+
|
773 |
+
verbose: bool = False,
|
774 |
+
):
|
775 |
+
if score_model is not None and score_model.strip():
|
776 |
+
load_8bit = False
|
777 |
+
load_4bit = False
|
778 |
+
load_half = False
|
779 |
+
base_model = score_model.strip()
|
780 |
+
tokenizer_base_model = ''
|
781 |
+
lora_weights = ''
|
782 |
+
llama_type = False
|
783 |
+
compile_model = False
|
784 |
+
smodel, stokenizer, sdevice = get_model(reward_type=True,
|
785 |
+
**get_kwargs(get_model, exclude_names=['reward_type'], **locals()))
|
786 |
else:
|
787 |
smodel, stokenizer, sdevice = None, None, None
|
788 |
return smodel, stokenizer, sdevice
|
|
|
808 |
'instruction_nochat',
|
809 |
'iinput_nochat',
|
810 |
'langchain_mode',
|
811 |
+
'top_k_docs',
|
812 |
'document_choice',
|
813 |
]
|
814 |
|
|
|
837 |
instruction_nochat,
|
838 |
iinput_nochat,
|
839 |
langchain_mode,
|
840 |
+
top_k_docs,
|
841 |
document_choice,
|
842 |
# END NOTE: Examples must have same order of parameters
|
843 |
src_lang=None,
|
|
|
847 |
save_dir=None,
|
848 |
sanitize_bot_response=True,
|
849 |
model_state0=None,
|
850 |
+
memory_restriction_level=None,
|
851 |
raise_generate_gpu_exceptions=None,
|
852 |
chat_context=None,
|
853 |
lora_weights=None,
|
854 |
load_db_if_exists=True,
|
855 |
dbs=None,
|
856 |
user_path=None,
|
857 |
+
detect_user_path_changes_every_query=None,
|
858 |
use_openai_embedding=None,
|
859 |
use_openai_model=None,
|
860 |
hf_embedding_model=None,
|
861 |
chunk=None,
|
862 |
chunk_size=None,
|
863 |
db_type=None,
|
|
|
864 |
n_jobs=None,
|
865 |
first_para=None,
|
866 |
text_limit=None,
|
867 |
+
verbose=False,
|
868 |
+
cli=False,
|
869 |
):
|
870 |
# ensure passed these
|
871 |
assert concurrency_count is not None
|
872 |
+
assert memory_restriction_level is not None
|
873 |
assert raise_generate_gpu_exceptions is not None
|
874 |
assert chat_context is not None
|
875 |
assert use_openai_embedding is not None
|
|
|
878 |
assert chunk is not None
|
879 |
assert chunk_size is not None
|
880 |
assert db_type is not None
|
881 |
+
assert top_k_docs is not None and isinstance(top_k_docs, int)
|
882 |
assert n_jobs is not None
|
883 |
assert first_para is not None
|
884 |
|
|
|
888 |
locals_dict.pop('model_state0', None)
|
889 |
print(locals_dict)
|
890 |
|
891 |
+
no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\nThen start New Conversation"
|
892 |
|
893 |
if model_state0 is None:
|
894 |
# e.g. for no gradio case, set dummy value, else should be set
|
|
|
938 |
db1 = dbs[langchain_mode]
|
939 |
else:
|
940 |
db1 = None
|
941 |
+
if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in non_hf_types:
|
942 |
query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
|
943 |
outr = ""
|
944 |
# use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
|
|
|
950 |
load_db_if_exists=load_db_if_exists,
|
951 |
db=db1,
|
952 |
user_path=user_path,
|
953 |
+
detect_user_path_changes_every_query=detect_user_path_changes_every_query,
|
954 |
max_new_tokens=max_new_tokens,
|
955 |
cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
|
956 |
use_openai_embedding=use_openai_embedding,
|
|
|
963 |
langchain_mode=langchain_mode,
|
964 |
document_choice=document_choice,
|
965 |
db_type=db_type,
|
966 |
+
k=top_k_docs,
|
967 |
temperature=temperature,
|
968 |
repetition_penalty=repetition_penalty,
|
969 |
top_k=top_k,
|
970 |
top_p=top_p,
|
971 |
prompt_type=prompt_type,
|
972 |
n_jobs=n_jobs,
|
973 |
+
verbose=verbose,
|
974 |
+
cli=cli,
|
975 |
):
|
976 |
+
outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer
|
977 |
+
yield dict(response=outr, sources=extra)
|
978 |
if save_dir:
|
979 |
save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
|
980 |
+
if verbose:
|
981 |
+
print(
|
982 |
+
'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
|
983 |
+
flush=True)
|
984 |
+
if outr or base_model in non_hf_types:
|
985 |
+
# if got no response (e.g. not showing sources and got no sources,
|
986 |
+
# so nothing to give to LLM), then slip through and ask LLM
|
987 |
+
# Or if llama/gptj, then just return since they had no response and can't go down below code path
|
988 |
return
|
989 |
|
990 |
if isinstance(tokenizer, str):
|
|
|
994 |
else:
|
995 |
raise RuntimeError("No such task type %s" % tokenizer)
|
996 |
# NOTE: uses max_length only
|
997 |
+
yield dict(response=model(prompt, max_length=max_new_tokens)[0][key], sources='')
|
998 |
|
999 |
if 'mbart-' in base_model.lower():
|
1000 |
assert src_lang is not None
|
|
|
1004 |
# override, ignore user change
|
1005 |
num_return_sequences = 1
|
1006 |
stopping_criteria = get_stopping(prompt_type, tokenizer, device)
|
1007 |
+
_, _, max_length_tokenize, max_prompt_length = get_cutoffs(memory_restriction_level, model_max_length=tokenizer.model_max_length)
|
1008 |
prompt = prompt[-max_prompt_length:]
|
1009 |
inputs = tokenizer(prompt,
|
1010 |
return_tensors="pt",
|
|
|
1015 |
if debug and len(inputs["input_ids"]) > 0:
|
1016 |
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
1017 |
input_ids = inputs["input_ids"].to(device)
|
1018 |
+
# CRITICAL LIMIT else will fail
|
1019 |
+
max_max_tokens = tokenizer.model_max_length
|
1020 |
+
max_input_tokens = max_max_tokens - max_new_tokens
|
1021 |
+
input_ids = input_ids[:, -max_input_tokens:]
|
1022 |
generation_config = GenerationConfig(
|
1023 |
temperature=float(temperature),
|
1024 |
top_p=float(top_p),
|
|
|
1071 |
# https://github.com/h2oai/h2ogpt/issues/104
|
1072 |
# but only makes sense if concurrency_count == 1
|
1073 |
context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
|
1074 |
+
if verbose:
|
1075 |
+
print('Pre-Generate: %s' % str(datetime.now()), flush=True)
|
1076 |
decoded_output = None
|
1077 |
with context_class("generate.lock"):
|
1078 |
+
if verbose:
|
1079 |
+
print('Generate: %s' % str(datetime.now()), flush=True)
|
1080 |
# decoded tokenized prompt can deviate from prompt due to special characters
|
1081 |
inputs_decoded = decoder(input_ids[0])
|
1082 |
inputs_decoded_raw = decoder_raw(input_ids[0])
|
|
|
1098 |
decoder = decoder_raw
|
1099 |
decoder_kwargs = decoder_raw_kwargs
|
1100 |
else:
|
1101 |
+
if verbose:
|
1102 |
+
print("WARNING: Special characters in prompt", flush=True)
|
1103 |
if stream_output:
|
1104 |
skip_prompt = False
|
1105 |
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
|
|
|
1118 |
if bucket.qsize() > 0 or thread.exc:
|
1119 |
thread.join()
|
1120 |
outputs += new_text
|
1121 |
+
yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded,
|
1122 |
+
sanitize_bot_response=sanitize_bot_response),
|
1123 |
+
sources='')
|
1124 |
except BaseException:
|
1125 |
# if any exception, raise that exception if was from thread, first
|
1126 |
if thread.exc:
|
|
|
1137 |
else:
|
1138 |
outputs = model.generate(**gen_kwargs)
|
1139 |
outputs = [decoder(s) for s in outputs.sequences]
|
1140 |
+
yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded,
|
1141 |
+
sanitize_bot_response=sanitize_bot_response), sources='')
|
1142 |
if outputs and len(outputs) >= 1:
|
1143 |
decoded_output = prompt + outputs[0]
|
1144 |
if save_dir and decoded_output:
|
1145 |
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
1146 |
+
if verbose:
|
1147 |
+
print('Post-Generate: %s decoded_output: %s' % (
|
1148 |
+
str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
|
1149 |
|
1150 |
|
1151 |
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
|
|
1153 |
inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
|
1154 |
|
1155 |
|
1156 |
+
def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=2048):
|
1157 |
# help to avoid errors like:
|
1158 |
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
|
1159 |
# RuntimeError: expected scalar type Half but found Float
|
1160 |
# with - 256
|
1161 |
+
if memory_restriction_level > 0:
|
1162 |
+
max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
|
1163 |
+
else:
|
1164 |
+
max_length_tokenize = model_max_length - 256
|
1165 |
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
1166 |
output_smallest = 30 * 4
|
1167 |
max_prompt_length = cutoff_len - output_smallest
|
|
|
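The new get_cutoffs() above derives character limits from a token budget. A small sketch reproducing its arithmetic (logic copied from the hunk above; the 4 chars-per-token factor and the 30-token minimum output are the values shown there):

# Sketch of the limits computed by the new get_cutoffs() (printed values are worked arithmetic)
def sketch_cutoffs(memory_restriction_level, model_max_length=2048):
    if memory_restriction_level > 0:
        max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
    else:
        max_length_tokenize = model_max_length - 256
    cutoff_len = max_length_tokenize * 4      # rough 4 chars-per-token factor
    max_prompt_length = cutoff_len - 30 * 4   # leave room for the smallest output
    return max_length_tokenize, cutoff_len, max_prompt_length

print(sketch_cutoffs(0))   # (1792, 7168, 7048) on a 2048-context model
print(sketch_cutoffs(2))   # (512, 2048, 1928) under the HF Spaces memory restriction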
1254 |
prompt_type, temperature, top_p, top_k, num_beams,
|
1255 |
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
1256 |
repetition_penalty, num_return_sequences,
|
1257 |
+
do_sample, k, verbose):
|
1258 |
use_defaults = False
|
1259 |
use_default_examples = True
|
1260 |
examples = []
|
|
|
1271 |
|
1272 |
if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
|
1273 |
prompt_type = inv_prompt_type_to_model_lower[model_lower]
|
1274 |
+
if verbose:
|
1275 |
+
print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
|
1276 |
|
1277 |
# examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
|
1278 |
if show_examples is None:
|
|
|
1335 |
prompt_type = prompt_type or 'plain'
|
1336 |
else:
|
1337 |
prompt_type = ''
|
|
|
|
|
|
|
1338 |
task_info = "No task"
|
1339 |
if prompt_type == 'instruct':
|
1340 |
task_info = "Answer question or follow imperative as instruction with optionally input."
|
|
|
1409 |
|
1410 |
# fit random forest classifier with 20 estimators""", ''] + params_list,
|
1411 |
]
|
1412 |
+
# add summary example
|
1413 |
+
examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else ''] + params_list]
|
1414 |
|
1415 |
src_lang = "English"
|
1416 |
tgt_lang = "Russian"
|
1417 |
|
1418 |
# move to correct position
|
1419 |
for example in examples:
|
1420 |
+
example += [chat, '', '', 'Disabled', k, ['All']]
|
1421 |
# adjust examples if non-chat mode
|
1422 |
if not chat:
|
1423 |
example[eval_func_param_names.index('instruction_nochat')] = example[
|
|
|
1489 |
return score
|
1490 |
|
1491 |
|
1492 |
+
def check_locals(**kwargs):
|
1493 |
+
# ensure everything in evaluate is here
|
1494 |
+
can_skip_because_locally_generated = [ # evaluate
|
1495 |
+
'instruction',
|
1496 |
+
'iinput',
|
1497 |
+
'context',
|
1498 |
+
'instruction_nochat',
|
1499 |
+
'iinput_nochat',
|
1500 |
+
# get_model:
|
1501 |
+
'reward_type'
|
1502 |
+
]
|
1503 |
+
for k in eval_func_param_names:
|
1504 |
+
if k in can_skip_because_locally_generated:
|
1505 |
+
continue
|
1506 |
+
assert k in kwargs, "Missing %s" % k
|
1507 |
+
for k in inputs_kwargs_list:
|
1508 |
+
if k in can_skip_because_locally_generated:
|
1509 |
+
continue
|
1510 |
+
assert k in kwargs, "Missing %s" % k
|
1511 |
+
|
1512 |
+
for k in list(inspect.signature(get_model).parameters):
|
1513 |
+
if k in can_skip_because_locally_generated:
|
1514 |
+
continue
|
1515 |
+
assert k in kwargs, "Missing %s" % k
|
1516 |
+
|
1517 |
+
|
1518 |
if __name__ == "__main__":
|
1519 |
"""
|
1520 |
Examples:
|
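The generate.py changes above also switch evaluate() to yielding dict(response=..., sources=...) instead of a bare string. A hypothetical consumer sketch (argument plumbing elided; it is normally assembled by gradio_runner.py or cli.py):

# Hypothetical consumer of the new evaluate() generator API; eval_args/eval_kwargs
# stand in for the full parameter list and are not shown here.
for res in evaluate(*eval_args, **eval_kwargs):
    print(res['response'], end='', flush=True)   # partial or full answer per yield
    sources = res['sources']                     # '' on the plain LLM path, populated on the langchain path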
gpt4all_llm.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import inspect
|
2 |
import os
|
|
|
3 |
from typing import Dict, Any, Optional, List
|
4 |
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
5 |
from pydantic import root_validator
|
@@ -21,11 +22,11 @@ class FakeTokenizer:
|
|
21 |
|
22 |
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
23 |
# defaults (some of these are generation parameters, so need to be passed in at generation time)
|
24 |
-
model_kwargs = dict(
|
25 |
-
n_threads=os.cpu_count() // 2,
|
26 |
temp=kwargs.get('temperature', 0.2),
|
27 |
top_p=kwargs.get('top_p', 0.75),
|
28 |
-
top_k=kwargs.get('top_k', 40)
|
|
|
29 |
env_gpt4all_file = ".env_gpt4all"
|
30 |
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
31 |
|
@@ -33,43 +34,103 @@ def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
|
33 |
if 'model_path_llama' not in model_kwargs:
|
34 |
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
|
35 |
model_path = model_kwargs.pop('model_path_llama')
|
|
|
|
|
36 |
from gpt4all import GPT4All as GPT4AllModel
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
41 |
from gpt4all import GPT4All as GPT4AllModel
|
|
|
42 |
else:
|
43 |
raise ValueError("No such base_model %s" % base_model)
|
44 |
-
func_names = list(inspect.signature(GPT4AllModel).parameters)
|
45 |
-
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
|
46 |
-
model = GPT4AllModel(model_path, **model_kwargs)
|
47 |
return model, FakeTokenizer(), 'cpu'
|
48 |
|
49 |
|
50 |
-
|
|
|
|
|
51 |
max_new_tokens=256,
|
52 |
temperature=0.1,
|
53 |
repetition_penalty=1.0,
|
54 |
top_k=40,
|
55 |
-
top_p=0.7
|
|
|
56 |
env_gpt4all_file = ".env_gpt4all"
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
if model_name == 'llama':
|
65 |
-
|
66 |
-
model_path =
|
67 |
-
|
|
|
|
68 |
else:
|
69 |
-
|
70 |
-
llm = H2OGPT4All(model=model_path, backend='gptj', callbacks=callbacks,
|
71 |
-
verbose=False, **default_params,
|
72 |
-
)
|
73 |
return llm
|
74 |
|
75 |
|
@@ -117,3 +178,78 @@ class H2OGPT4All(gpt4all.GPT4All):
|
|
117 |
if verbose:
|
118 |
print("_call prompt: %s" % prompt, flush=True)
|
119 |
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
|
|
|
|
|
|
|
|
|
|
1 |
import inspect
|
2 |
import os
|
3 |
+
import sys
|
4 |
from typing import Dict, Any, Optional, List
|
5 |
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
6 |
from pydantic import root_validator
|
|
|
22 |
|
23 |
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
24 |
# defaults (some of these are generation parameters, so need to be passed in at generation time)
|
25 |
+
model_kwargs = dict(n_threads=os.cpu_count() // 2,
|
|
|
26 |
temp=kwargs.get('temperature', 0.2),
|
27 |
top_p=kwargs.get('top_p', 0.75),
|
28 |
+
top_k=kwargs.get('top_k', 40),
|
29 |
+
n_ctx=2048 - 256)
|
30 |
env_gpt4all_file = ".env_gpt4all"
|
31 |
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
32 |
|
|
|
34 |
if 'model_path_llama' not in model_kwargs:
|
35 |
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
|
36 |
model_path = model_kwargs.pop('model_path_llama')
|
37 |
+
# FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
|
38 |
+
from llama_cpp import Llama
|
39 |
+
# llama sets some things at init model time, not generation time
|
40 |
+
func_names = list(inspect.signature(Llama.__init__).parameters)
|
41 |
+
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
|
42 |
+
model_kwargs['n_ctx'] = int(model_kwargs['n_ctx'])
|
43 |
+
model = Llama(model_path=model_path, **model_kwargs)
|
44 |
+
elif base_model in "gpt4all_llama":
|
45 |
+
if 'model_name_gpt4all_llama' not in model_kwargs and 'model_path_gpt4all_llama' not in model_kwargs:
|
46 |
+
raise ValueError("No model_name_gpt4all_llama or model_path_gpt4all_llama in %s" % env_gpt4all_file)
|
47 |
+
model_name = model_kwargs.pop('model_name_gpt4all_llama')
|
48 |
+
model_type = 'llama'
|
49 |
from gpt4all import GPT4All as GPT4AllModel
|
50 |
+
model = GPT4AllModel(model_name=model_name, model_type=model_type)
|
51 |
+
elif base_model in "gptj":
|
52 |
+
if 'model_name_gptj' not in model_kwargs and 'model_path_gptj' not in model_kwargs:
|
53 |
+
raise ValueError("No model_name_gpt4j or model_path_gpt4j in %s" % env_gpt4all_file)
|
54 |
+
model_name = model_kwargs.pop('model_name_gptj')
|
55 |
+
model_type = 'gptj'
|
56 |
from gpt4all import GPT4All as GPT4AllModel
|
57 |
+
model = GPT4AllModel(model_name=model_name, model_type=model_type)
|
58 |
else:
|
59 |
raise ValueError("No such base_model %s" % base_model)
|
|
|
|
|
|
|
60 |
return model, FakeTokenizer(), 'cpu'
|
61 |
|
62 |
|
63 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
64 |
+
|
65 |
+
|
66 |
+
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
67 |
+
|
68 |
+
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
69 |
+
"""Run on new LLM token. Only available when streaming is enabled."""
|
70 |
+
# streaming to std already occurs without this
|
71 |
+
# sys.stdout.write(token)
|
72 |
+
# sys.stdout.flush()
|
73 |
+
pass
|
74 |
+
|
75 |
+
|
76 |
+
def get_model_kwargs(env_kwargs, default_kwargs, cls):
|
77 |
+
# default from class
|
78 |
+
model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
|
79 |
+
# from our defaults
|
80 |
+
model_kwargs.update(default_kwargs)
|
81 |
+
# from user defaults
|
82 |
+
model_kwargs.update(env_kwargs)
|
83 |
+
# ensure only valid keys
|
84 |
+
func_names = list(inspect.signature(cls).parameters)
|
85 |
+
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
|
86 |
+
return model_kwargs
|
87 |
+
|
88 |
+
|
89 |
+
def get_llm_gpt4all(model_name,
|
90 |
+
model=None,
|
91 |
max_new_tokens=256,
|
92 |
temperature=0.1,
|
93 |
repetition_penalty=1.0,
|
94 |
top_k=40,
|
95 |
+
top_p=0.7,
|
96 |
+
verbose=False):
|
97 |
env_gpt4all_file = ".env_gpt4all"
|
98 |
+
env_kwargs = dotenv_values(env_gpt4all_file)
|
99 |
+
callbacks = [H2OStreamingStdOutCallbackHandler()]
|
100 |
+
n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
|
101 |
+
default_kwargs = dict(context_erase=0.5,
|
102 |
+
n_batch=1,
|
103 |
+
n_ctx=n_ctx,
|
104 |
+
n_predict=max_new_tokens,
|
105 |
+
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
|
106 |
+
repeat_penalty=repetition_penalty,
|
107 |
+
temp=temperature,
|
108 |
+
temperature=temperature,
|
109 |
+
top_k=top_k,
|
110 |
+
top_p=top_p,
|
111 |
+
use_mlock=True,
|
112 |
+
verbose=verbose)
|
113 |
if model_name == 'llama':
|
114 |
+
cls = H2OLlamaCpp
|
115 |
+
model_path = env_kwargs.pop('model_path_llama') if model is None else model
|
116 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
117 |
+
model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
|
118 |
+
llm = cls(**model_kwargs)
|
119 |
+
llm.client.verbose = verbose
|
120 |
+
elif model_name == 'gpt4all_llama':
|
121 |
+
cls = H2OGPT4All
|
122 |
+
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
|
123 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
124 |
+
model_kwargs.update(dict(model=model_path, backend='llama', callbacks=callbacks))
|
125 |
+
llm = cls(**model_kwargs)
|
126 |
+
elif model_name == 'gptj':
|
127 |
+
cls = H2OGPT4All
|
128 |
+
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
|
129 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
130 |
+
model_kwargs.update(dict(model=model_path, backend='gptj', callbacks=callbacks))
|
131 |
+
llm = cls(**model_kwargs)
|
132 |
else:
|
133 |
+
raise RuntimeError("No such model_name %s" % model_name)
|
|
|
|
|
|
|
134 |
return llm
|
135 |
|
136 |
|
|
|
178 |
if verbose:
|
179 |
print("_call prompt: %s" % prompt, flush=True)
|
180 |
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
181 |
+
|
182 |
+
|
183 |
+
from langchain.llms import LlamaCpp
|
184 |
+
|
185 |
+
|
186 |
+
class H2OLlamaCpp(LlamaCpp):
|
187 |
+
model_path: Any
|
188 |
+
"""Path to the pre-trained GPT4All model file."""
|
189 |
+
|
190 |
+
@root_validator()
|
191 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
192 |
+
"""Validate that llama-cpp-python library is installed."""
|
193 |
+
if isinstance(values["model_path"], str):
|
194 |
+
model_path = values["model_path"]
|
195 |
+
model_param_names = [
|
196 |
+
"lora_path",
|
197 |
+
"lora_base",
|
198 |
+
"n_ctx",
|
199 |
+
"n_parts",
|
200 |
+
"seed",
|
201 |
+
"f16_kv",
|
202 |
+
"logits_all",
|
203 |
+
"vocab_only",
|
204 |
+
"use_mlock",
|
205 |
+
"n_threads",
|
206 |
+
"n_batch",
|
207 |
+
"use_mmap",
|
208 |
+
"last_n_tokens_size",
|
209 |
+
]
|
210 |
+
model_params = {k: values[k] for k in model_param_names}
|
211 |
+
# For backwards compatibility, only include if non-null.
|
212 |
+
if values["n_gpu_layers"] is not None:
|
213 |
+
model_params["n_gpu_layers"] = values["n_gpu_layers"]
|
214 |
+
|
215 |
+
try:
|
216 |
+
from llama_cpp import Llama
|
217 |
+
|
218 |
+
values["client"] = Llama(model_path, **model_params)
|
219 |
+
except ImportError:
|
220 |
+
raise ModuleNotFoundError(
|
221 |
+
"Could not import llama-cpp-python library. "
|
222 |
+
"Please install the llama-cpp-python library to "
|
223 |
+
"use this embedding model: pip install llama-cpp-python"
|
224 |
+
)
|
225 |
+
except Exception as e:
|
226 |
+
raise ValueError(
|
227 |
+
f"Could not load Llama model from path: {model_path}. "
|
228 |
+
f"Received error {e}"
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
values["client"] = values["model_path"]
|
232 |
+
return values
|
233 |
+
|
234 |
+
def _call(
|
235 |
+
self,
|
236 |
+
prompt: str,
|
237 |
+
stop: Optional[List[str]] = None,
|
238 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
239 |
+
) -> str:
|
240 |
+
verbose = False
|
241 |
+
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
|
242 |
+
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
243 |
+
num_prompt_tokens = len(prompt_tokens)
|
244 |
+
if num_prompt_tokens > self.n_ctx:
|
245 |
+
# conservative by using int()
|
246 |
+
chars_per_token = int(len(prompt) / num_prompt_tokens)
|
247 |
+
prompt = prompt[-self.n_ctx * chars_per_token:]
|
248 |
+
if verbose:
|
249 |
+
print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True)
|
250 |
+
prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
251 |
+
num_prompt_tokens2 = len(prompt_tokens2)
|
252 |
+
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
|
253 |
+
if verbose:
|
254 |
+
print("_call prompt: %s" % prompt, flush=True)
|
255 |
+
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
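To make the reworked gpt4all_llm.py entry point concrete, a hedged usage sketch (the model file name is a placeholder; the real path comes from the .env_gpt4all file read via dotenv_values as shown above):

# Hedged usage sketch of get_llm_gpt4all(); requires a .env_gpt4all file such as:
#   model_path_llama=./models/some-llama-model.ggml.bin
from gpt4all_llm import get_llm_gpt4all

llm = get_llm_gpt4all('llama', max_new_tokens=256, temperature=0.2, top_p=0.75, top_k=40)
print(llm("Who are you?"))  # the attached H2OStreamingStdOutCallbackHandler is a no-op token handler (see above)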
gpt_langchain.py
CHANGED
@@ -3,6 +3,7 @@ import inspect
|
|
3 |
import os
|
4 |
import pathlib
|
5 |
import pickle
|
|
|
6 |
import shutil
|
7 |
import subprocess
|
8 |
import sys
|
@@ -16,9 +17,11 @@ from functools import reduce
|
|
16 |
from operator import concat
|
17 |
|
18 |
from joblib import Parallel, delayed
|
|
|
19 |
|
|
|
20 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
21 |
-
get_device
|
22 |
|
23 |
import_matplotlib()
|
24 |
|
@@ -35,7 +38,6 @@ from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, Pytho
|
|
35 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
36 |
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
|
37 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
38 |
-
from langchain.vectorstores import FAISS
|
39 |
from langchain.chains.question_answering import load_qa_chain
|
40 |
from langchain.docstore.document import Document
|
41 |
from langchain import PromptTemplate
|
@@ -43,17 +45,36 @@ from langchain.vectorstores import Chroma

 def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
            hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
     if not sources:
         return None
     # get embedding model
     embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)

     # Create vector database
     if db_type == 'faiss':
         db = FAISS.from_documents(sources, embedding)
     elif db_type == 'chroma':
         os.makedirs(persist_directory, exist_ok=True)
         db = Chroma.from_documents(documents=sources,
                                    embedding=embedding,
@@ -61,34 +82,121 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo
                                    collection_name=collection_name,
                                    anonymized_telemetry=False)
         db.persist()
-        # FIXME: below just proves can load persistent dir, regenerates its embedding files, so a bit wasteful
-        if False:
-            db = Chroma(embedding_function=embedding,
-                        persist_directory=persist_directory,
-                        collection_name=collection_name)
     else:
         raise RuntimeError("No such db_type=%s" % db_type)

     return db


-def
     if not sources:
-        return db
     if db_type == 'faiss':
         db.add_documents(sources)
     elif db_type == 'chroma':
         db.add_documents(documents=sources)
         db.persist()
     else:
         raise RuntimeError("No such db_type=%s" % db_type)

     return db
@@ -126,19 +234,23 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
                 top_k=40,
                 top_p=0.7,
                 prompt_type=None,
                 ):
     if use_openai_model:
         from langchain.llms import OpenAI
         llm = OpenAI(temperature=0)
         model_name = 'openai'
         streamer = None
         from gpt4all_llm import get_llm_gpt4all
         llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
                               temperature=temperature,
                               repetition_penalty=repetition_penalty,
                               top_k=top_k,
                               top_p=top_p,
                               )
         streamer = None
         prompt_type = 'plain'
@@ -149,6 +261,7 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
         # only used if didn't pass model in
         assert model_name is None
         assert tokenizer is None
         model_name = 'h2oai/h2ogpt-oasst1-512-12b'
         # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
         # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
@@ -165,7 +278,12 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
                                                      torch_dtype=torch_dtype,
                                                      load_in_8bit=load_8bit)

     if stream_output:
         skip_prompt = False
         from generate import H2OTextIteratorStreamer
@@ -175,17 +293,19 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
     else:
         streamer = None

     from langchain.llms import HuggingFacePipeline
     llm = HuggingFacePipeline(pipeline=pipe)
@@ -341,6 +461,12 @@ try:
 except (pkg_resources.DistributionNotFound, AssertionError):
     have_arxiv = False

 image_types = ["png", "jpg", "jpeg"]
 non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
                    "md", "html",
@@ -357,9 +483,10 @@ file_types = non_image_types + image_types

 def add_meta(docs1, file):
     file_extension = pathlib.Path(file).suffix
     if not isinstance(docs1, list):
         docs1 = [docs1]
-    [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now))) for x in docs1]
@@ -409,42 +536,45 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
         f.write(file)
         metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
         doc1 = Document(page_content=file, metadata=metadata)
-    elif file.endswith('.html') or file.endswith('.mhtml'):
         docs1 = UnstructuredHTMLLoader(file_path=file).load()
         add_meta(docs1, file)
         doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif (file.endswith('.docx') or file.endswith('.doc')) and have_libreoffice:
         docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
         add_meta(docs1, file)
         doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif file.endswith('.odt'):
         docs1 = UnstructuredODTLoader(file_path=file).load()
         add_meta(docs1, file)
         doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif file.endswith('pptx') or file.endswith('ppt'):
         docs1 = UnstructuredPowerPointLoader(file_path=file).load()
         add_meta(docs1, file)
         doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif file.endswith('.txt'):
         # use UnstructuredFileLoader ?
         add_meta(doc1, file)
-    elif file.endswith('.rtf'):
         docs1 = UnstructuredRTFLoader(file).load()
         add_meta(docs1, file)
         doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif file.endswith('.md'):
         docs1 = UnstructuredMarkdownLoader(file).load()
         add_meta(docs1, file)
         doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif file.endswith('.enex'):
         add_meta(doc1, file)
         docs1 = UnstructuredEPubLoader(file).load()
         add_meta(docs1, file)
         doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif file.endswith('.jpeg') or file.endswith('.jpg') or file.endswith('.png'):
         docs1 = []
         if have_tesseract and enable_ocr:
             # OCR, somewhat works, but not great
@@ -471,13 +601,14 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
             docs1.extend(docs1c)
         for doci in docs1:
             doci.metadata['source'] = doci.metadata['image_path']
         if docs1:
             doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif file.endswith('.msg'):
         raise RuntimeError("Not supported, GPL3 license")
         # docs1 = OutlookMessageLoader(file).load()
         # docs1[0].metadata['source'] = file
-    elif file.endswith('.eml'):
         try:
             docs1 = UnstructuredEmailLoader(file).load()
             add_meta(docs1, file)
@@ -491,34 +622,43 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
             doc1 = chunk_sources(docs1, chunk_size=chunk_size)
         else:
             raise
-    # elif file.endswith('.gcsdir'):
     #  doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
-    # elif file.endswith('.gcsfile'):
     #  doc1 = GCSFileLoader(project_name, bucket, blob).load()
-    elif file.endswith('.rst'):
         with open(file, "r") as f:
             doc1 = Document(page_content=f.read(), metadata={"source": file})
         add_meta(doc1, file)
-    elif file.endswith('.pdf'):
         # Some PDFs return nothing or junk from PDFMinerLoader
-        # e.g. Beyond fine-tuning_ Classifying high resolution mammograms using function-preserving transformations _ Elsevier Enhanced Reader.pdf
-        doc1 = PyPDFLoader(file).load_and_split()
         add_meta(doc1, file)
-    elif file.endswith('.csv'):
         doc1 = CSVLoader(file).load()
         add_meta(doc1, file)
-    elif file.endswith('.py'):
         doc1 = PythonLoader(file).load()
         add_meta(doc1, file)
-    elif file.endswith('.toml'):
         doc1 = TomlLoader(file).load()
         add_meta(doc1, file)
-    elif file.endswith('.urls'):
         with open(file, "r") as f:
             docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
         add_meta(docs1, file)
         doc1 = chunk_sources(docs1, chunk_size=chunk_size)
-    elif file.endswith('.zip'):
         with zipfile.ZipFile(file, 'r') as zip_ref:
             # don't put into temporary path, since want to keep references to docs inside zip
             # so just extract in path where
@@ -529,11 +669,17 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
         raise RuntimeError("No file handler for %s" % os.path.basename(file))

     # allow doc1 to be list or not. If not list, did not chunk yet, so chunk now
     if not isinstance(doc1, list):
         if chunk:
             docs = chunk_sources([doc1], chunk_size=chunk_size)
         else:
             docs = [doc1]
     else:
         docs = doc1
@@ -590,6 +736,8 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
                  captions_model=None,
                  caption_loader=None,
                  enable_ocr=False,
                  ):
     globs_image_types = []
     globs_non_image_types = []
@@ -617,6 +765,28 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
     # But instead, allow fail so can collect unsupported too
     set_globs_image_types = set(globs_image_types)
     globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])
     # could use generator, but messes up metadata handling in recursive case
     if caption_loader and not isinstance(caption_loader, (bool, str)) and \
             caption_loader.device != 'cpu' or \
@@ -643,21 +813,21 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
     if n_jobs != 1 and len(globs_non_image_types) > 1:
         # avoid nesting, e.g. upload 1 zip and then inside many files
         # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
-        documents =
             delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
         )
     else:
-        documents = [path_to_doc1(file, **kwargs) for file in globs_non_image_types]

     # do images separately since can't fork after cuda in parent, so can't be parallel
     if n_jobs_image != 1 and len(globs_image_types) > 1:
         # avoid nesting, e.g. upload 1 zip and then inside many files
         # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
-        image_documents =
             delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
         )
     else:
-        image_documents = [path_to_doc1(file, **kwargs) for file in globs_image_types]

     # add image docs in
     documents += image_documents
@@ -676,7 +846,9 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
     return documents


-def prep_langchain(persist_directory,
                    hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
     """
     do prep first time, involving downloads
@@ -685,12 +857,18 @@ def prep_langchain(persist_directory, load_db_if_exists, db_type, use_openai_emb
     """
     assert langchain_mode not in ['MyData'], "Should not prep scratch data"

         print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
         db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                              hf_embedding_model)
     else:
         db = None
         if langchain_mode in ['All', 'DriverlessAI docs']:
             # FIXME: Could also just use dai_docs.pickle directly and upload that
@@ -701,19 +879,52 @@ def prep_langchain(persist_directory, load_db_if_exists, db_type, use_openai_emb

     langchain_kwargs = kwargs_make_db.copy()
     langchain_kwargs.update(locals())
-    db = make_db(**langchain_kwargs)

     return db


 def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                     hf_embedding_model):
     if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
             os.path.join(persist_directory, 'index')):
         print("DO Loading db: %s" % langchain_mode, flush=True)
         embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
         db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
-                    collection_name=langchain_mode.replace(' ', '_')
         print("DONE Loading db: %s" % langchain_mode, flush=True)
         return db
     return None
@@ -740,21 +951,40 @@ def _make_db(use_openai_embedding=False,
              langchain_mode=None,
              user_path=None,
              db_type='faiss',
-             load_db_if_exists=
              db=None,
-             n_jobs=-1
     persist_directory = 'db_dir_%s' % langchain_mode  # single place, no special names for each case
     if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
             os.path.join(persist_directory, 'index')):
         assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
-        print("Loading db", flush=True)
         embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
         db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
-                    collection_name=langchain_mode.replace(' ', '_')
     assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
     if langchain_mode in ['wiki_full', 'All', "'All'"]:
         from read_wiki_full import get_all_documents
         small_test = None
@@ -783,9 +1013,25 @@ def _make_db(use_openai_embedding=False,
         sources.extend(sources1)
     if langchain_mode in ['All', 'UserData']:
         if user_path:
             # chunk internally for speed over multiple docs
-            sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size
             sources.extend(sources1)
         else:
             print("Chose UserData but user_path is empty/None", flush=True)
     if False and langchain_mode in ['urls', 'All', "'All'"]:
@@ -797,14 +1043,48 @@ def _make_db(use_openai_embedding=False,
         sources1 = loader.load()
         sources.extend(sources1)
     if not sources:

 source_prefix = "Sources [Score | Link]:"
@@ -828,6 +1108,7 @@ def _run_qa_db(query=None,
                use_openai_model=False, use_openai_embedding=False,
                first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
                user_path=None,
                db_type='faiss',
                model_name=None, model=None, tokenizer=None,
                hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
@@ -847,7 +1128,9 @@ def _run_qa_db(query=None,
                top_p=0.7,
                langchain_mode=None,
                document_choice=['All'],
-               n_jobs=-1
     """

     :param query:
@@ -859,17 +1142,19 @@ def _run_qa_db(query=None,
     :param chunk:
     :param chunk_size:
     :param user_path: user path to glob recursively from
-    :param db_type: 'faiss' for in-memory db or 'chroma' for persistent db
     :param model_name: model name, used to switch behaviors
     :param model: pre-initialized model, else will make new one
     :param tokenizer: pre-initialized tokenizer, else will make new one.  Required not None if model is not None
     :param answer_with_sources
     :return:
     """
     llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
                                                          model=model, tokenizer=tokenizer,
                                                          stream_output=stream_output,
@@ -879,74 +1164,173 @@ def _run_qa_db(query=None,
                                                          top_k=top_k,
                                                          top_p=top_p,
                                                          prompt_type=prompt_type,
                                                          )

-    if model_name in
         # FIXME: for now, streams to stdout/stderr currently
         stream_output = False

     if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
         use_context = False
-        template = """%s{context}{question}""" % prefix
     else:
         use_context = True
-        template = """%s
-==
-{context}
-==
-{question}""" % prefix
-        prompt = PromptTemplate(
-            # input_variables=["summaries", "question"],
-            input_variables=["context", "question"],
-            template=template,
-        )
-        chain = load_qa_chain(llm, prompt=prompt)
     else:
-        chain = load_qa_with_sources_chain(llm)
         use_context = True

-    if query is None:
-        query = "What are the main differences between Linux and Windows?"
     # https://github.com/hwchase17/langchain/issues/1946
     # FIXME: Seems to way to get size of chroma db to limit k to avoid
     # Chroma collection MyData contains fewer than 4 elements.
     # type logger error
     k_db = 1000 if db_type == 'chroma' else k  # k=100 works ok too for

     if db and use_context:
         if isinstance(document_choice, str):
             # support string as well
             document_choice = [document_choice]
-        if not isinstance(db, Chroma) or
             # treat empty list as All for now, not 'None'
             filter_kwargs = {}
         else:
             if len(document_choice) >= 2:
                 or_filter = [{"source": {"$eq": x}} for x in document_choice]
                 filter_kwargs = dict(filter={"$or": or_filter})
                 one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
                 filter_kwargs = dict(filter=one_filter)
             k_db = 1
             k = 0
         docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
         # cut off so no high distance docs/sources considered
         docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
         scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
-        if len(scores) > 0:
             print("Distance: min: %s max: %s mean: %s median: %s" %
                   (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
     else:
         docs = []
         scores = []

-    if not docs and use_context:

     common_words_file = "data/NGSL_1.2_stats.csv.zip"
     if os.path.isfile(common_words_file):
@@ -958,88 +1342,82 @@ def _run_qa_db(query=None,
             num_common = len([x.lower() in set_common for x in reduced_query_words])
             frac_common = num_common / len(reduced_query) if reduced_query else 0
             # FIXME: report to user bad query that uses too many common words

-    if
         chain_kwargs = dict(input_documents=[], question=query)
     else:
         chain_kwargs = dict(input_documents=docs, question=query)

-        assert streamer is not None
-        target = wrapped_partial(chain, chain_kwargs)
-        import queue
-        bucket = queue.Queue()
-        thread = EThread(target=target, streamer=streamer, bucket=bucket)
-        thread.start()
-        outputs = ""
-        prompt = None  # FIXME
-        try:
-            for new_text in streamer:
-                # print("new_text: %s" % new_text, flush=True)
-                if bucket.qsize() > 0 or thread.exc:
-                    thread.join()
-                outputs += new_text
-                if prompter:  # and False:  # FIXME: pipeline can already use prompter
-                    output1 = prompter.get_response(outputs, prompt=prompt,
-                                                    sanitize_bot_response=sanitize_bot_response)
-                    yield output1
-                else:
-                    yield outputs
-        except BaseException:
-            # if any exception, raise that exception if was from thread, first
-            if thread.exc:
-                raise thread.exc
-            raise
-        finally:
-            # in case no exception and didn't join with thread yet, then join
-            if not thread.exc:
-                answer = thread.join()
-        # in case raise StopIteration or broke queue loop in streamer, but still have exception
-        if thread.exc:
-            raise thread.exc
-        # FIXME: answer is not string outputs from streamer. How to get actual final output?
-        # answer = outputs
-    else:
-        answer = chain(chain_kwargs)

-    elif answer is not None:
         print("query: %s" % query, flush=True)
         print("answer: %s" % answer['output_text'], flush=True)
-        # link
-        answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
-                          zip(scores, answer['input_documents'])]
-        answer_sources_dict = defaultdict(list)
-        [answer_sources_dict[url].append(score) for score, url in answer_sources]
-        answers_dict = {}
-        for url, scores_url in answer_sources_dict.items():
-            answers_dict[url] = np.max(scores_url)
-        answer_sources = [(score, url) for url, score in answers_dict.items()]
-        answer_sources.sort(key=lambda x: x[0], reverse=True)
-        if show_rank:
-            # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
-            # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
-            answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
-            sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
-        else:
-            answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
-            sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
-        sorted_sources_urls += f"</ul></p>{source_postfix}"

-        else:
-        ret = answer['output_text']


 def chunk_sources(sources, chunk_size=1024):
import os
import pathlib
import pickle
import queue
import shutil
import subprocess
import sys

from operator import concat

from joblib import Parallel, delayed
from tqdm import tqdm

from prompter import non_hf_types
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
    get_device, ProgressParallel, remove, hash_file

import_matplotlib()

    EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
    UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain import PromptTemplate


def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
           collection_name=None,
           hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    if not sources:
        return None
    # get embedding model
    embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
    assert collection_name is not None or langchain_mode != 'notset'
    if collection_name is None:
        collection_name = langchain_mode.replace(' ', '_')

    # Create vector database
    if db_type == 'faiss':
        from langchain.vectorstores import FAISS
        db = FAISS.from_documents(sources, embedding)

    elif db_type == 'weaviate':
        import weaviate
        from weaviate.embedded import EmbeddedOptions
        from langchain.vectorstores import Weaviate

        # TODO: add support for connecting via docker compose
        client = weaviate.Client(
            embedded_options=EmbeddedOptions()
        )
        index_name = collection_name.capitalize()
        db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
                                     index_name=index_name)

    elif db_type == 'chroma':
        assert persist_directory is not None
        os.makedirs(persist_directory, exist_ok=True)
        db = Chroma.from_documents(documents=sources,
                                   embedding=embedding,
                                   collection_name=collection_name,
                                   anonymized_telemetry=False)
        db.persist()
    else:
        raise RuntimeError("No such db_type=%s" % db_type)

    return db
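A rough usage sketch of the Weaviate path added to get_db() above, assuming the weaviate-client package with embedded support is installed; the example source document is hypothetical:

from langchain.docstore.document import Document

sources = [Document(page_content="h2oGPT supports Weaviate as a vector store.",
                    metadata={"source": "example.txt"})]
# embedded Weaviate spins up locally, no separate server needed
db = get_db(sources, db_type='weaviate', langchain_mode='UserData')
docs = db.similarity_search("Which vector stores are supported?", k=1)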
def _get_unique_sources_in_weaviate(db):
    batch_size = 100
    id_source_list = []
    result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)

    while result['objects']:
        id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
        last_id = id_source_list[-1][0]
        result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)

    unique_sources = {source for _, source in id_source_list}
    return unique_sources


def add_to_db(db, sources, db_type='faiss',
              avoid_dup_by_file=False,
              avoid_dup_by_content=True):
    num_new_sources = len(sources)
    if not sources:
        return db, num_new_sources, []
    if db_type == 'faiss':
        db.add_documents(sources)
    elif db_type == 'weaviate':
        # FIXME: only control by file name, not hash yet
        if avoid_dup_by_file or avoid_dup_by_content:
            unique_sources = _get_unique_sources_in_weaviate(db)
            sources = [x for x in sources if x.metadata['source'] not in unique_sources]
        num_new_sources = len(sources)
        if num_new_sources == 0:
            return db, num_new_sources, []
        db.add_documents(documents=sources)
    elif db_type == 'chroma':
        collection = db.get()
        # files we already have:
        metadata_files = set([x['source'] for x in collection['metadatas']])
        if avoid_dup_by_file:
            # Too weak in case file changed content, assume parent shouldn't pass true for this for now
            raise RuntimeError("Not desired code path")
            sources = [x for x in sources if x.metadata['source'] not in metadata_files]
        if avoid_dup_by_content:
            # look at hash, instead of page_content
            # migration: If no hash previously, avoid updating,
            # since don't know if need to update and may be expensive to redo all unhashed files
            metadata_hash_ids = set(
                [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
            # avoid sources with same hash
            sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
            # get new file names that match existing file names.  delete existing files we are overridding
            dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
            print("Removing %s duplicate files from db because ingesting those as new documents" % len(
                dup_metadata_files), flush=True)
            client_collection = db._client.get_collection(name=db._collection.name)
            for dup_file in dup_metadata_files:
                dup_file_meta = dict(source=dup_file)
                try:
                    client_collection.delete(where=dup_file_meta)
                except KeyError:
                    pass
        num_new_sources = len(sources)
        if num_new_sources == 0:
            return db, num_new_sources, []
        db.add_documents(documents=sources)
        db.persist()
    else:
        raise RuntimeError("No such db_type=%s" % db_type)

    new_sources_metadata = [x.metadata for x in sources]

    return db, num_new_sources, new_sources_metadata
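The content-hash dedup in add_to_db() boils down to dropping any incoming chunk whose file hash already appears in the stored metadata. A minimal standalone sketch of that filter (names here are illustrative):

def filter_new_sources(sources, existing_metadatas):
    # hashes already present in the collection; ignore unhashed legacy entries
    existing_hashes = {m.get('hashid') for m in existing_metadatas if m.get('hashid') not in (None, "None")}
    return [doc for doc in sources if doc.metadata.get('hashid') not in existing_hashes]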
def create_or_update_db(db_type, persist_directory, collection_name,
                        sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model):
    if db_type == 'weaviate':
        import weaviate
        from weaviate.embedded import EmbeddedOptions

        # TODO: add support for connecting via docker compose
        client = weaviate.Client(
            embedded_options=EmbeddedOptions()
        )
        index_name = collection_name.replace(' ', '_').capitalize()
        if client.schema.exists(index_name) and not add_if_exists:
            client.schema.delete_class(index_name)
            if verbose:
                print("Removing %s" % index_name, flush=True)
    elif db_type == 'chroma':
        if not os.path.isdir(persist_directory) or not add_if_exists:
            if os.path.isdir(persist_directory):
                if verbose:
                    print("Removing %s" % persist_directory, flush=True)
                remove(persist_directory)
            if verbose:
                print("Generating db", flush=True)

    if not add_if_exists:
        if verbose:
            print("Generating db", flush=True)
    else:
        if verbose:
            print("Loading and updating db", flush=True)

    db = get_db(sources,
                use_openai_embedding=use_openai_embedding,
                db_type=db_type,
                persist_directory=persist_directory,
                langchain_mode=collection_name,
                hf_embedding_model=hf_embedding_model)

    return db
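A hypothetical call to create_or_update_db() for a persistent Chroma collection; `sources` is assumed to be a list of langchain Documents produced elsewhere (e.g. by path_to_docs()):

db = create_or_update_db(db_type='chroma',
                         persist_directory='db_dir_UserData',
                         collection_name='UserData',
                         sources=sources,
                         use_openai_embedding=False,
                         add_if_exists=True,
                         verbose=True,
                         hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")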
                top_k=40,
                top_p=0.7,
                prompt_type=None,
                prompter=None,
                verbose=False,
                ):
    if use_openai_model:
        from langchain.llms import OpenAI
        llm = OpenAI(temperature=0)
        model_name = 'openai'
        streamer = None
        prompt_type = 'plain'
    elif model_name in non_hf_types:
        from gpt4all_llm import get_llm_gpt4all
        llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
                              temperature=temperature,
                              repetition_penalty=repetition_penalty,
                              top_k=top_k,
                              top_p=top_p,
                              verbose=verbose,
                              )
        streamer = None
        prompt_type = 'plain'

        # only used if didn't pass model in
        assert model_name is None
        assert tokenizer is None
        prompt_type = 'human_bot'
        model_name = 'h2oai/h2ogpt-oasst1-512-12b'
        # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
        # model_name = 'h2oai/h2ogpt-oasst1-512-20b'

                                                     torch_dtype=torch_dtype,
                                                     load_in_8bit=load_8bit)

    max_max_tokens = tokenizer.model_max_length
    gen_kwargs = dict(max_new_tokens=max_new_tokens,
                      return_full_text=True,
                      early_stopping=False,
                      handle_long_generation='hole')

    if stream_output:
        skip_prompt = False
        from generate import H2OTextIteratorStreamer

    else:
        streamer = None

    from h2oai_pipeline import H2OTextGenerationPipeline
    pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
                                     prompter=prompter,
                                     prompt_type=prompt_type,
                                     sanitize_bot_response=True,
                                     chat=False, stream_output=stream_output,
                                     tokenizer=tokenizer,
                                     max_input_tokens=max_max_tokens - max_new_tokens,
                                     **gen_kwargs)
    # pipe.task = "text-generation"
    # below makes it listen only to our prompt removal,
    # not built in prompt removal that is less general and not specific for our model
    pipe.task = "text2text-generation"

    from langchain.llms import HuggingFacePipeline
    llm = HuggingFacePipeline(pipeline=pipe)
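The max_input_tokens passed to the pipeline above is simply the model context size minus the generation budget; as a worked example with assumed numbers:

model_max_length = 2048          # tokenizer.model_max_length for a 2048-token context model
max_new_tokens = 256             # tokens reserved for the answer
max_input_tokens = model_max_length - max_new_tokens  # 1792 tokens remain for the prompt + context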
except (pkg_resources.DistributionNotFound, AssertionError):
    have_arxiv = False

try:
    assert pkg_resources.get_distribution('pymupdf') is not None
    have_pymupdf = True
except (pkg_resources.DistributionNotFound, AssertionError):
    have_pymupdf = False

image_types = ["png", "jpg", "jpeg"]
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
                   "md", "html",

def add_meta(docs1, file):
    file_extension = pathlib.Path(file).suffix
    hashid = hash_file(file)
    if not isinstance(docs1, list):
        docs1 = [docs1]
    [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]


def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
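utils.hash_file itself is not shown in this diff; a typical implementation would hash the raw file bytes, along the lines of this hypothetical sketch:

import hashlib

def hash_file(path, blocksize=65536):
    # stream the file in blocks so large documents don't load fully into memory
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for block in iter(lambda: f.read(blocksize), b''):
            h.update(block)
    return h.hexdigest()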
        f.write(file)
        metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
        doc1 = Document(page_content=file, metadata=metadata)
    elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
        docs1 = UnstructuredHTMLLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
        docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.odt'):
        docs1 = UnstructuredODTLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
        docs1 = UnstructuredPowerPointLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.txt'):
        # use UnstructuredFileLoader ?
        docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
        # makes just one, but big one
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
        add_meta(doc1, file)
    elif file.lower().endswith('.rtf'):
        docs1 = UnstructuredRTFLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.md'):
        docs1 = UnstructuredMarkdownLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.enex'):
        docs1 = EverNoteLoader(file).load()
        add_meta(doc1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.epub'):
        docs1 = UnstructuredEPubLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
        docs1 = []
        if have_tesseract and enable_ocr:
            # OCR, somewhat works, but not great

            docs1.extend(docs1c)
        for doci in docs1:
            doci.metadata['source'] = doci.metadata['image_path']
            doci.metadata['hash'] = hash_file(doci.metadata['source'])
        if docs1:
            doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.msg'):
        raise RuntimeError("Not supported, GPL3 license")
        # docs1 = OutlookMessageLoader(file).load()
        # docs1[0].metadata['source'] = file
    elif file.lower().endswith('.eml'):
        try:
            docs1 = UnstructuredEmailLoader(file).load()
            add_meta(docs1, file)

            doc1 = chunk_sources(docs1, chunk_size=chunk_size)
        else:
            raise
    # elif file.lower().endswith('.gcsdir'):
    #  doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
    # elif file.lower().endswith('.gcsfile'):
    #  doc1 = GCSFileLoader(project_name, bucket, blob).load()
    elif file.lower().endswith('.rst'):
        with open(file, "r") as f:
            doc1 = Document(page_content=f.read(), metadata={"source": file})
        add_meta(doc1, file)
    elif file.lower().endswith('.pdf'):
        env_gpt4all_file = ".env_gpt4all"
        from dotenv import dotenv_values
        env_kwargs = dotenv_values(env_gpt4all_file)
        pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
        if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
            # GPL, only use if installed
            from langchain.document_loaders import PyMuPDFLoader
            doc1 = PyMuPDFLoader(file).load_and_split()
        else:
            # open-source fallback
            doc1 = PyPDFLoader(file).load_and_split()
        # Some PDFs return nothing or junk from PDFMinerLoader
        add_meta(doc1, file)
    elif file.lower().endswith('.csv'):
        doc1 = CSVLoader(file).load()
        add_meta(doc1, file)
    elif file.lower().endswith('.py'):
        doc1 = PythonLoader(file).load()
        add_meta(doc1, file)
    elif file.lower().endswith('.toml'):
        doc1 = TomlLoader(file).load()
        add_meta(doc1, file)
    elif file.lower().endswith('.urls'):
        with open(file, "r") as f:
            docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.zip'):
        with zipfile.ZipFile(file, 'r') as zip_ref:
            # don't put into temporary path, since want to keep references to docs inside zip
            # so just extract in path where

        raise RuntimeError("No file handler for %s" % os.path.basename(file))

    # allow doc1 to be list or not.  If not list, did not chunk yet, so chunk now
    # if list of length one, don't trust and chunk it
    if not isinstance(doc1, list):
        if chunk:
            docs = chunk_sources([doc1], chunk_size=chunk_size)
        else:
            docs = [doc1]
    elif isinstance(doc1, list) and len(doc1) == 1:
        if chunk:
            docs = chunk_sources(doc1, chunk_size=chunk_size)
        else:
            docs = doc1
    else:
        docs = doc1
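As a quick illustration of the dispatcher above (assuming a local 'README.md' exists), each returned chunk carries the hashid set by add_meta():

docs = file_to_doc('README.md', chunk=True, chunk_size=512)
print(len(docs), docs[0].metadata.get('hashid'))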
                 captions_model=None,
                 caption_loader=None,
                 enable_ocr=False,
                 existing_files=[],
                 existing_hash_ids={},
                 ):
    globs_image_types = []
    globs_non_image_types = []

    # But instead, allow fail so can collect unsupported too
    set_globs_image_types = set(globs_image_types)
    globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])

    # filter out any files to skip (e.g. if already processed them)
    # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[]
    assert not existing_files, "DEV: assume not using this approach"
    if existing_files:
        set_skip_files = set(existing_files)
        globs_image_types = [x for x in globs_image_types if x not in set_skip_files]
        globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files]
    if existing_hash_ids:
        # assume consistent with add_meta() use of hash_file(file)
        # also assume consistent with get_existing_hash_ids for dict creation
        # assume hashable values
        existing_hash_ids_set = set(existing_hash_ids.items())
        hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items())
        hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items())
        # don't use symmetric diff.  If file is gone, ignore and don't remove or something
        # just consider existing files (key) having new hash or not (value)
        new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys())
        new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys())
        globs_image_types = [x for x in globs_image_types if x in new_files_image]
        globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image]

    # could use generator, but messes up metadata handling in recursive case
    if caption_loader and not isinstance(caption_loader, (bool, str)) and \
            caption_loader.device != 'cpu' or \

    if n_jobs != 1 and len(globs_non_image_types) > 1:
        # avoid nesting, e.g. upload 1 zip and then inside many files
        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
        documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
            delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
        )
    else:
        documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_non_image_types)]

    # do images separately since can't fork after cuda in parent, so can't be parallel
    if n_jobs_image != 1 and len(globs_image_types) > 1:
        # avoid nesting, e.g. upload 1 zip and then inside many files
        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
        image_documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
            delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
        )
    else:
        image_documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_image_types)]

    # add image docs in
    documents += image_documents

    return documents
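The change detection above compares {path: hash} mappings as sets of items, so a file is re-ingested when it is either new or its hash changed. A tiny self-contained example of that set-difference trick:

existing = {'a.txt': 'h1', 'b.txt': 'h2'}
current = {'a.txt': 'h1', 'b.txt': 'h3', 'c.txt': 'h4'}
changed_or_new = set(dict(set(current.items()) - set(existing.items())).keys())
# changed_or_new == {'b.txt', 'c.txt'}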
def prep_langchain(persist_directory,
                   load_db_if_exists,
                   db_type, use_openai_embedding, langchain_mode, user_path,
                   hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
    """
    do prep first time, involving downloads

    """
    assert langchain_mode not in ['MyData'], "Should not prep scratch data"

    db_dir_exists = os.path.isdir(persist_directory)

    if db_dir_exists and user_path is None:
        print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
        db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                             hf_embedding_model)
    else:
        if db_dir_exists and user_path is not None:
            print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
                persist_directory, user_path), flush=True)
        elif not db_dir_exists:
            print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
        db = None
        if langchain_mode in ['All', 'DriverlessAI docs']:
            # FIXME: Could also just use dai_docs.pickle directly and upload that

        langchain_kwargs = kwargs_make_db.copy()
        langchain_kwargs.update(locals())
        db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs)

    return db


import posthog

posthog.disabled = True


class FakeConsumer(object):
    def __init__(self, *args, **kwargs):
        pass

    def run(self):
        pass

    def pause(self):
        pass

    def upload(self):
        pass

    def next(self):
        pass

    def request(self, batch):
        pass


posthog.Consumer = FakeConsumer


def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                    hf_embedding_model):
    if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
            os.path.join(persist_directory, 'index')):
        print("DO Loading db: %s" % langchain_mode, flush=True)
        embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
        from chromadb.config import Settings
        client_settings = Settings(anonymized_telemetry=False,
                                   chroma_db_impl="duckdb+parquet",
                                   persist_directory=persist_directory)
        db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
                    collection_name=langchain_mode.replace(' ', '_'),
                    client_settings=client_settings)
        print("DONE Loading db: %s" % langchain_mode, flush=True)
        return db
    return None
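For example, a previously persisted UserData collection can be reloaded with get_existing_db() without re-embedding anything (directory name assumed):

db = get_existing_db('db_dir_UserData', load_db_if_exists=True, db_type='chroma',
                     use_openai_embedding=False, langchain_mode='UserData',
                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")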
             langchain_mode=None,
             user_path=None,
             db_type='faiss',
             load_db_if_exists=True,
             db=None,
             n_jobs=-1,
             verbose=False):
    persist_directory = 'db_dir_%s' % langchain_mode  # single place, no special names for each case
    if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
            os.path.join(persist_directory, 'index')):
        assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
        print("Loading existing db", flush=True)
        embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
        from chromadb.config import Settings
        client_settings = Settings(anonymized_telemetry=False,
                                   chroma_db_impl="duckdb+parquet",
                                   persist_directory=persist_directory)
        db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
                    collection_name=langchain_mode.replace(' ', '_'),
                    client_settings=client_settings)
    sources = []
    if not db and langchain_mode not in ['MyData'] or \
            user_path is not None and \
            langchain_mode in ['UserData']:
        # Should not make MyData db this way, why avoided, only upload from UI
        assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
        if verbose:
            if langchain_mode in ['UserData']:
                if user_path is not None:
                    print("Checking if changed or new sources in %s, and generating sources them" % user_path,
                          flush=True)
                elif db is None:
                    print("user_path not passed and no db, no sources", flush=True)
                else:
                    print("user_path not passed, using only existing db, no new sources", flush=True)
            else:
                print("Generating %s sources" % langchain_mode, flush=True)
        if langchain_mode in ['wiki_full', 'All', "'All'"]:
            from read_wiki_full import get_all_documents
            small_test = None

        sources.extend(sources1)
    if langchain_mode in ['All', 'UserData']:
        if user_path:
            if db is not None:
                # NOTE: Ignore file names for now, only go by hash ids
                # existing_files = get_existing_files(db)
                existing_files = []
                existing_hash_ids = get_existing_hash_ids(db)
            else:
                # pretend no existing files so won't filter
                existing_files = []
                existing_hash_ids = []
            # chunk internally for speed over multiple docs
            sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
                                    existing_files=existing_files, existing_hash_ids=existing_hash_ids)
            new_metadata_sources = set([x.metadata['source'] for x in sources1])
            if new_metadata_sources:
                print("Loaded %s new files as sources to add to UserData" % len(new_metadata_sources), flush=True)
                if verbose:
                    print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
            sources.extend(sources1)
            print("Loaded %s sources for potentially adding to UserData" % len(sources), flush=True)
        else:
            print("Chose UserData but user_path is empty/None", flush=True)
    if False and langchain_mode in ['urls', 'All', "'All'"]:

        sources1 = loader.load()
        sources.extend(sources1)
    if not sources:
        if verbose:
            if db is not None:
                print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True)
            else:
                print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True)
        return db, 0, []
    if verbose:
        if db is not None:
            print("Generating db", flush=True)
        else:
            print("Adding to db", flush=True)
    if not db:
        if sources:
            db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
                        persist_directory=persist_directory, langchain_mode=langchain_mode,
                        hf_embedding_model=hf_embedding_model)
            if verbose:
                print("Generated db", flush=True)
        else:
            print("Did not generate db since no sources", flush=True)
        new_sources_metadata = [x.metadata for x in sources]
    elif user_path is not None and langchain_mode in ['UserData']:
        print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
        db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type)
        print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
    else:
        new_sources_metadata = [x.metadata for x in sources]

    return db, len(new_sources_metadata), new_sources_metadata


def get_existing_files(db):
    collection = db.get()
    metadata_sources = set([x['source'] for x in collection['metadatas']])
    return metadata_sources


def get_existing_hash_ids(db):
    collection = db.get()
    # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
    metadata_hash_ids = {x['source']: x.get('hashid') for x in collection['metadatas']}
    return metadata_hash_ids


source_prefix = "Sources [Score | Link]:"
use_openai_model=False, use_openai_embedding=False,
|
1109 |
first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
|
1110 |
user_path=None,
|
1111 |
+
detect_user_path_changes_every_query=False,
|
1112 |
db_type='faiss',
|
1113 |
model_name=None, model=None, tokenizer=None,
|
1114 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
|
|
1128 |
top_p=0.7,
|
1129 |
langchain_mode=None,
|
1130 |
document_choice=['All'],
|
1131 |
+
n_jobs=-1,
|
1132 |
+
verbose=False,
|
1133 |
+
cli=False):
|
1134 |
"""
|
1135 |
|
1136 |
:param query:
|
|
|
1142 |
:param chunk:
|
1143 |
:param chunk_size:
|
1144 |
:param user_path: user path to glob recursively from
|
1145 |
+
:param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
|
1146 |
:param model_name: model name, used to switch behaviors
|
1147 |
:param model: pre-initialized model, else will make new one
|
1148 |
     :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
     :param answer_with_sources
     :return:
     """
+    assert query is not None
+    assert prompter is not None or prompt_type is not None or model is None  # if model is None, then will generate
+    if prompter is not None:
+        prompt_type = prompter.prompt_type
+    if model is not None:
+        assert prompt_type is not None
     llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
                                                          model=model, tokenizer=tokenizer,
                                                          stream_output=stream_output,
                                                          top_k=top_k,
                                                          top_p=top_p,
                                                          prompt_type=prompt_type,
+                                                         prompter=prompter,
+                                                         verbose=verbose,
                                                          )
 
+    if model_name in non_hf_types:
         # FIXME: for now, streams to stdout/stderr currently
         stream_output = False
 
+    use_context = False
+    scores = []
+    chain = None
+
+    func_names = list(inspect.signature(get_similarity_chain).parameters)
+    sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
+    missing_kwargs = [x for x in func_names if x not in sim_kwargs]
+    assert not missing_kwargs, "Missing: %s" % missing_kwargs
+    docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
+    if len(document_choice) > 0 and document_choice[0] == 'Only':
+        formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
+        yield formatted_doc_chunks, ''
+        return
+    if chain is None and model_name not in non_hf_types:
+        # can only return if HF type
+        return
+
+    if stream_output:
+        answer = None
+        assert streamer is not None
+        import queue
+        bucket = queue.Queue()
+        thread = EThread(target=chain, streamer=streamer, bucket=bucket)
+        thread.start()
+        outputs = ""
+        prompt = None  # FIXME
+        try:
+            for new_text in streamer:
+                # print("new_text: %s" % new_text, flush=True)
+                if bucket.qsize() > 0 or thread.exc:
+                    thread.join()
+                outputs += new_text
+                if prompter:  # and False:  # FIXME: pipeline can already use prompter
+                    output1 = prompter.get_response(outputs, prompt=prompt,
+                                                    sanitize_bot_response=sanitize_bot_response)
+                    yield output1, ''
+                else:
+                    yield outputs, ''
+        except BaseException:
+            # if any exception, raise that exception if was from thread, first
+            if thread.exc:
+                raise thread.exc
+            raise
+        finally:
+            # in case no exception and didn't join with thread yet, then join
+            if not thread.exc:
+                answer = thread.join()
+        # in case raise StopIteration or broke queue loop in streamer, but still have exception
+        if thread.exc:
+            raise thread.exc
+        # FIXME: answer is not string outputs from streamer. How to get actual final output?
+        # answer = outputs
+    else:
+        answer = chain()
+
+    if not use_context:
+        ret = answer['output_text']
+        extra = ''
+        yield ret, extra
+    elif answer is not None:
+        ret, extra = get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=verbose)
+        yield ret, extra
+    return
+
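The block above forwards only the locals that match `get_similarity_chain`'s signature, using `inspect.signature` plus `locals()`. A minimal standalone sketch of that pattern (illustrative names, not part of the repo):

```python
import inspect


def target(a, b=2, c=3):
    return a + b + c


def caller():
    a = 10
    b = 20
    unrelated = 'ignored'
    # keep only the locals that appear in target()'s signature
    func_names = list(inspect.signature(target).parameters)
    kwargs = {k: v for k, v in locals().items() if k in func_names}
    missing = [x for x in func_names if x not in kwargs]
    # here 'c' is missing, so target() falls back to its default
    return target(**kwargs)


print(caller())  # 33
```

In the diff itself, any missing name trips the assert instead of silently falling back to a default, which keeps the caller and callee signatures in sync.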
+def get_similarity_chain(query=None,
+                         use_openai_model=False, use_openai_embedding=False,
+                         first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
+                         user_path=None,
+                         detect_user_path_changes_every_query=False,
+                         db_type='faiss',
+                         model_name=None,
+                         hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
+                         prompt_type=None,
+                         cut_distanct=1.1,
+                         load_db_if_exists=False,
+                         db=None,
+                         langchain_mode=None,
+                         document_choice=['All'],
+                         n_jobs=-1,
+                         # beyond run_db_query:
+                         llm=None,
+                         verbose=False,
+                         ):
+    # determine whether use of context out of docs is planned
+    if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
         if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
             use_context = False
         else:
             use_context = True
     else:
         use_context = True
 
     # https://github.com/hwchase17/langchain/issues/1946
     # FIXME: Seems to way to get size of chroma db to limit k to avoid
     # Chroma collection MyData contains fewer than 4 elements.
     # type logger error
     k_db = 1000 if db_type == 'chroma' else k  # k=100 works ok too for
 
+    # FIXME: For All just go over all dbs instead of a separate db for All
+    if not detect_user_path_changes_every_query and db is not None:
+        # avoid looking at user_path during similarity search db handling,
+        # if already have db and not updating from user_path every query
+        # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
+        user_path = None
+    db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
+                                                        hf_embedding_model=hf_embedding_model,
+                                                        first_para=first_para, text_limit=text_limit, chunk=chunk,
+                                                        chunk_size=chunk_size,
+                                                        langchain_mode=langchain_mode,
+                                                        user_path=user_path,
+                                                        db_type=db_type,
+                                                        load_db_if_exists=load_db_if_exists,
+                                                        db=db,
+                                                        n_jobs=n_jobs,
+                                                        verbose=verbose)
+
     if db and use_context:
         if isinstance(document_choice, str):
             # support string as well
             document_choice = [document_choice]
+        if not isinstance(db, Chroma) or \
+                len(document_choice) == 0 or \
+                len(document_choice) <= 1 and document_choice[0] == 'All':
             # treat empty list as All for now, not 'None'
             filter_kwargs = {}
+        elif len(document_choice) > 0 and document_choice[0] == 'Only':
+            # Only means All docs, but only will return sources, not LLM response
+            filter_kwargs = {}
         else:
             if len(document_choice) >= 2:
                 or_filter = [{"source": {"$eq": x}} for x in document_choice]
                 filter_kwargs = dict(filter={"$or": or_filter})
+            elif len(document_choice) > 0:
                 one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
                 filter_kwargs = dict(filter=one_filter)
+            else:
+                filter_kwargs = {}
+        if len(document_choice) == 1 and document_choice[0] == 'None':
             k_db = 1
             k = 0
         docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
         # cut off so no high distance docs/sources considered
         docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
         scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
+        if len(scores) > 0 and verbose:
             print("Distance: min: %s max: %s mean: %s median: %s" %
                   (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
     else:
         docs = []
         scores = []
 
+    if not docs and use_context and model_name not in non_hf_types:
+        # if HF type and have no docs, can bail out
+        return docs, None, [], False
+
+    if len(document_choice) > 0 and document_choice[0] == 'Only':
+        # no LLM use
+        return docs, None, [], False
 
     common_words_file = "data/NGSL_1.2_stats.csv.zip"
     if os.path.isfile(common_words_file):
         num_common = len([x.lower() in set_common for x in reduced_query_words])
         frac_common = num_common / len(reduced_query) if reduced_query else 0
         # FIXME: report to user bad query that uses too many common words
+        if verbose:
+            print("frac_common: %s" % frac_common, flush=True)
+
+    if len(docs) == 0:
+        # avoid context == in prompt then
+        use_context = False
 
+    if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
+        # instruct-like, rather than few-shot prompt_type='plain' as default
+        # but then sources confuse the model with how inserted among rest of text, so avoid
+        prefix = ""
+        if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
+            template = """%s{context}{question}""" % prefix
+        else:
+            template = """%s
+==
+{context}
+==
+{question}""" % prefix
+        prompt = PromptTemplate(
+            # input_variables=["summaries", "question"],
+            input_variables=["context", "question"],
+            template=template,
+        )
+        chain = load_qa_chain(llm, prompt=prompt)
+    else:
+        chain = load_qa_with_sources_chain(llm)
+
+    if not use_context:
         chain_kwargs = dict(input_documents=[], question=query)
     else:
         chain_kwargs = dict(input_documents=docs, question=query)
 
+    target = wrapped_partial(chain, chain_kwargs)
+    return docs, target, scores, use_context
+
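The `document_choice` handling above builds a Chroma metadata filter using the `$or`/`$eq` where-filter syntax. A simplified standalone sketch of that mapping (illustrative helper name, not from the repo; the 'None' and k-shrinking details are handled by the caller as in the diff):

```python
def choice_to_filter_kwargs(document_choice, is_chroma=True):
    # Simplified sketch of the filter selection above; only Chroma supports these filters here.
    if not is_chroma or not document_choice or document_choice[0] in ('All', 'Only'):
        return {}
    if len(document_choice) >= 2:
        return dict(filter={"$or": [{"source": {"$eq": x}} for x in document_choice]})
    return dict(filter={"source": {"$eq": document_choice[0]}})


print(choice_to_filter_kwargs(['a.pdf', 'b.pdf']))
# {'filter': {'$or': [{'source': {'$eq': 'a.pdf'}}, {'source': {'$eq': 'b.pdf'}}]}}
```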
+def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
+    if verbose:
         print("query: %s" % query, flush=True)
         print("answer: %s" % answer['output_text'], flush=True)
 
+    if len(answer['input_documents']) == 0:
+        extra = ''
+        ret = answer['output_text'] + extra
+        return ret, extra
+
+    # link
+    answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
+                      zip(scores, answer['input_documents'])]
+    answer_sources_dict = defaultdict(list)
+    [answer_sources_dict[url].append(score) for score, url in answer_sources]
+    answers_dict = {}
+    for url, scores_url in answer_sources_dict.items():
+        answers_dict[url] = np.max(scores_url)
+    answer_sources = [(score, url) for url, score in answers_dict.items()]
+    answer_sources.sort(key=lambda x: x[0], reverse=True)
+    if show_rank:
+        # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
+        # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
+        answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
+        sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
+    else:
+        answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
+        sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
+        sorted_sources_urls += f"</ul></p>{source_postfix}"
 
+    if not answer['output_text'].endswith('\n'):
+        answer['output_text'] += '\n'
 
+    if answer_with_sources:
+        extra = '\n' + sorted_sources_urls
+    else:
+        extra = ''
+    ret = answer['output_text'] + extra
+    return ret, extra
 
 
 def chunk_sources(sources, chunk_size=1024):
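`get_sources_answer` converts vector-store distances into a 0-1 relevance (anything at distance 1.5 or more counts as 0) and keeps only the best score per source URL before ranking. A minimal standalone sketch of that scoring (illustrative function, not from the repo):

```python
from collections import defaultdict


def rank_sources(scores, urls):
    # map distance to a 0-1 relevance and keep the best score per URL
    relevance = [max(0.0, 1.5 - s) / 1.5 for s in scores]
    best = defaultdict(float)
    for r, url in zip(relevance, urls):
        best[url] = max(best[url], r)
    return sorted(best.items(), key=lambda x: x[1], reverse=True)


print(rank_sources([0.3, 0.9, 0.3], ['a.pdf', 'b.pdf', 'a.pdf']))
# [('a.pdf', 0.8), ('b.pdf', 0.4)]
```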
gradio_runner.py
CHANGED
Removed lines per hunk (the full updated content for each hunk follows below):

@@ -9,17 +9,33 @@ import traceback
-    prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt
-    ping, get_short_name, get_url, makedirs
-import gradio as gr
@@ -27,12 +43,11 @@ def go_gradio(**kwargs):
-
-    queue = True
@@ -41,7 +56,6 @@ def go_gradio(**kwargs):
-    allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
@@ -50,6 +64,8 @@ def go_gradio(**kwargs):
@@ -76,8 +92,8 @@ def go_gradio(**kwargs):
-        description += "If this host is busy, try [12B](https://gpt.h2o.ai), [
-        description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
@@ -95,6 +111,7 @@ def go_gradio(**kwargs):
@@ -131,7 +148,19 @@ body.dark{#warning {background-color: #555555};}
-
@@ -173,7 +202,11 @@ body.dark{#warning {background-color: #555555};}
-
@@ -258,10 +291,10 @@ body.dark{#warning {background-color: #555555};}
-    clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
-    export_chats_btn = gr.Button(value="Export Chats to Download")
-    remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
-    add_to_chats_btn = gr.Button("Import Chats from Upload")
@@ -269,7 +302,7 @@ body.dark{#warning {background-color: #555555};}
-    langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
@@ -302,7 +335,7 @@ body.dark{#warning {background-color: #555555};}
-    label="Choose Subset of Doc(s) in Collection [click get to update]",
@@ -312,6 +345,8 @@ body.dark{#warning {background-color: #555555};}
@@ -375,7 +410,7 @@ body.dark{#warning {background-color: #555555};}
-    label="Download File
@@ -411,14 +446,24 @@ body.dark{#warning {background-color: #555555};}
-    max_beams = 8 if not (
-
@@ -450,11 +495,19 @@ body.dark{#warning {background-color: #555555};}
-    load_msg = "Load-Unload Model/LORA" if not is_public \
-    load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
@@ -468,7 +521,7 @@ body.dark{#warning {background-color: #555555};}
-    load_model_button = gr.Button(load_msg)
@@ -476,19 +529,12 @@ body.dark{#warning {background-color: #555555};}
-    label="GPU ID
-    with gr.Row():
-        with gr.Column(scale=50):
-            new_model = gr.Textbox(label="New Model HF name/path")
-            new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
-        with gr.Column(scale=1):
-            add_model_button = gr.Button("Add new model name")
-            add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
@@ -499,7 +545,7 @@ body.dark{#warning {background-color: #555555};}
-    load_model_button2 = gr.Button(load_msg2)
@@ -508,12 +554,22 @@ body.dark{#warning {background-color: #555555};}
-    label="GPU ID [-1 = all GPUs, if choose is enabled]",
@@ -530,7 +586,7 @@ body.dark{#warning {background-color: #555555};}
-    file_output = gr.File(interactive=False)
@@ -542,7 +598,7 @@ body.dark{#warning {background-color: #555555};}
-    description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/tos.md">Terms of Service</a></i></li></ul></p>"""
@@ -633,24 +689,37 @@ body.dark{#warning {background-color: #555555};}
-    get_sources1 = functools.partial(get_sources, dbs=dbs)
-    return gr.Dropdown.update(choices=
-    return gr.Dropdown.update(choices=x, value=
-    show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
-
@@ -661,10 +730,6 @@ body.dark{#warning {background-color: #555555};}
-    # Get inputs to evaluate()
-    # don't deepcopy, can contain model itself
-    all_kwargs = kwargs.copy()
-    all_kwargs.update(locals())
@@ -714,7 +779,10 @@ body.dark{#warning {background-color: #555555};}
-
@@ -811,6 +879,8 @@ body.dark{#warning {background-color: #555555};}
@@ -830,6 +900,43 @@ body.dark{#warning {background-color: #555555};}
@@ -861,47 +968,15 @@ body.dark{#warning {background-color: #555555};}
-    # ensure output will be unique to models
-    _, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
-    history = copy.deepcopy(history)
-
-
-
-    prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
-    chat1 = args_list[eval_func_param_names.index('chat')]
-    context1 = ''
-    # - 1 below because current instruction already in history from user()
-    for histi in range(0, len(history) - 1):
-        data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
-        prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
-                                                                             chat1, reduced=True)
-        # md -> back to text, maybe not super important if model trained enough
-        if not kwargs['keep_sources_in_context']:
-            from gpt_langchain import source_prefix, source_postfix
-            import re
-            prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
-                            flags=re.DOTALL)
-            if prompt.endswith('\n<p>'):
-                prompt = prompt[:-4]
-        prompt = prompt.replace('<br>', chat_sep)
-        if not prompt.endswith(chat_sep):
-            prompt += chat_sep
-        # most recent first, add older if can
-        # only include desired chat history
-        if len(prompt + context1) > max_prompt_length:
-            break
-        context1 = prompt + context1
-
-    _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
-                                                                    reduced=True)
-    if context1 and not context1.endswith(chat_sep):
-        context1 += chat_sep  # ensure if terminates abruptly, then human continues on next line
@@ -909,8 +984,11 @@ body.dark{#warning {background-color: #555555};}
-    for
-
@@ -1067,11 +1145,11 @@ body.dark{#warning {background-color: #555555};}
-    questionx = stepx[0].replace('<p>', '').replace('</p>', '')
-    answerx = stepx[1].replace('<p>', '').replace('</p>', '')
-    questiony = stepy[0].replace('<p>', '').replace('</p>', '')
-    answery = stepy[1].replace('<p>', '').replace('</p>', '')
@@ -1221,7 +1299,9 @@ body.dark{#warning {background-color: #555555};}
-    model1, tokenizer1, device1 = get_model(
@@ -1242,7 +1322,7 @@ body.dark{#warning {background-color: #555555};}
-    load_model_event = load_model_button.click(**load_model_args) \
@@ -1255,7 +1335,8 @@ body.dark{#warning {background-color: #555555};}
-    load_model_event2 = load_model_button2.click(**load_model_args2
@@ -1331,6 +1412,27 @@ body.dark{#warning {background-color: #555555};}
@@ -1339,7 +1441,7 @@ body.dark{#warning {background-color: #555555};}
-    kwargs['base_model'] not in
@@ -1348,14 +1450,15 @@ body.dark{#warning {background-color: #555555};}
-    kwargs['base_model'] not in
-
@@ -1384,7 +1487,7 @@ def get_inputs_list(inputs_dict, model_lower):
-def get_sources(db1, langchain_mode, dbs=None):
@@ -1407,7 +1510,7 @@ def get_sources(db1, langchain_mode, dbs=None):
-    source_list =
@@ -1471,7 +1574,7 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
-    add_to_db(db1[0], sources, db_type=db_type)
@@ -1486,13 +1589,13 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
-    source_files_added = get_source_files(db1[0], exceptions=exceptions)
-    add_to_db(dbs[langchain_mode], sources, db_type=db_type)
@@ -1504,11 +1607,11 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
-    source_files_added = get_source_files(dbs[langchain_mode], exceptions=exceptions)
-def
@@ -1519,17 +1622,31 @@ def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=No
-    return
-def get_source_files(db, exceptions=None):
-
-
-
@@ -1558,28 +1675,28 @@ def get_source_files(db, exceptions=None):
-
-{0}
-""".format(source_files_added, exceptions_html)
-
-{
-""".format(source_files_added)
@@ -1594,6 +1711,31 @@ def get_source_files(db, exceptions=None):
-
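Among the additions below, the new `count_chat_tokens` handler reports prompt size by tokenizing the rebuilt chat context. A minimal standalone sketch of that counting step (the `'gpt2'` model name is only a placeholder for illustration; the app uses whatever tokenizer is loaded):

```python
from transformers import AutoTokenizer


def count_tokens(text, model_name='gpt2'):
    # tokenize and report the sequence length, as count_chat_tokens does below
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer(text, return_tensors='pt')['input_ids'].shape[1]


print(count_tokens('<human>: hello\n<bot>: hi there\n'))
```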
 import uuid
 import filelock
 import pandas as pd
+import requests
 import tabulate
 
+# This is a hack to prevent Gradio from phoning home when it gets imported
+os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+
+
+def my_get(url, **kwargs):
+    print('Gradio HTTP request redirected to localhost :)', flush=True)
+    kwargs.setdefault('allow_redirects', True)
+    return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
+
+
+original_get = requests.get
+requests.get = my_get
+import gradio as gr
+
+requests.get = original_get
+
 from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
 from prompter import Prompter, \
+    prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt, non_hf_types
 from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
+    ping, get_short_name, get_url, makedirs, get_kwargs
 from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
     inputs_kwargs_list, get_cutoffs, scratch_base_dir
 
 from apscheduler.schedulers.background import BackgroundScheduler
 
 allow_api = kwargs['allow_api']
 is_public = kwargs['is_public']
 is_hf = kwargs['is_hf']
+memory_restriction_level = kwargs['memory_restriction_level']
 n_gpus = kwargs['n_gpus']
 admin_pass = kwargs['admin_pass']
 model_state0 = kwargs['model_state0']
 score_model_state0 = kwargs['score_model_state0']
 dbs = kwargs['dbs']
 db_type = kwargs['db_type']
 visible_langchain_modes = kwargs['visible_langchain_modes']
 
 enable_sources_list = kwargs['enable_sources_list']
 enable_url_upload = kwargs['enable_url_upload']
 enable_text_upload = kwargs['enable_text_upload']
 use_openai_embedding = kwargs['use_openai_embedding']
 hf_embedding_model = kwargs['hf_embedding_model']
 enable_captions = kwargs['enable_captions']
 
 caption_loader = kwargs['caption_loader']
 
 # easy update of kwargs needed for evaluate() etc.
+queue = True
+allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
 kwargs.update(locals())
 
 if 'mbart-' in kwargs['model_lower']:
 
 """
 else:
     description = more_info
+    description += "If this host is busy, try [12B](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
+    description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)</p>"""
 if is_hf:
     description += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
 
 else:
     css_code = """footer {visibility: hidden}"""
 css_code += """
+@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
 body.dark{#warning {background-color: #555555};}
 #small_btn {
     margin: 0.6em 0em 0.55em 0;
 
 Chatbot._postprocess_chat_messages = _postprocess_chat_messages
 
+if kwargs['gradio_offline_level'] >= 0:
+    # avoid GoogleFont that pulls from internet
+    if kwargs['gradio_offline_level'] == 1:
+        # front end would still have to download fonts or have cached it at some point
+        base_font = 'Source Sans Pro'
+    else:
+        base_font = 'Helvetica'
+    theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'),
+                        font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'))
+else:
+    theme_kwargs = dict()
+
+theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs)
 demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
 callback = gr.CSVLogger()
 
 lora_options_state = gr.State([lora_options])
 my_db_state = gr.State([None, None])
 chat_state = gr.State({})
+# make user default first and default choice, dedup
+docs_state00 = kwargs['document_choice'] + ['All', 'Only', 'None']
+docs_state0 = []
+[docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
+docs_state = gr.State(docs_state0)  # first is chosen as default
 gr.Markdown(f"""
 {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
 
 radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
                        type='value')
 with gr.Row():
+    clear_chat_btn = gr.Button(value="Clear Chat", visible=True).style(size='sm')
+    export_chats_btn = gr.Button(value="Export Chats to Download").style(size='sm')
+    remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True).style(size='sm')
+    add_to_chats_btn = gr.Button("Import Chats from Upload").style(size='sm')
 with gr.Row():
     chats_file = gr.File(interactive=False, label="Download Exported Chats")
     chatsup_output = gr.File(label="Upload Chat File(s)",
                              file_count='multiple',
                              elem_id="warning", elem_classes="feedback")
 with gr.TabItem("Data Source"):
+    langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/docs/README_LangChain.md',
                                from_str=True)
     gr.HTML(value=f"""LangChain Support Disabled<p>
     Run:<p>
 
 with data_row2:
     with gr.Column(scale=50):
         document_choice = gr.Dropdown(docs_state.value,
+                                      label="Choose Subset of Doc(s) in Collection [click get sources to update]",
                                       value=docs_state.value[0],
                                       interactive=True,
                                       multiselect=True,
                     ).style(full_width=False, size='sm')
     show_sources_btn = gr.Button(value="Show Sources",
                                  ).style(full_width=False, size='sm')
+    refresh_sources_btn = gr.Button(value="Refresh Sources",
+                                    ).style(full_width=False, size='sm')
 
 # import control
 if kwargs['langchain_mode'] != 'Disabled':
 
 with sources_row3:
     with gr.Column(scale=1):
         file_source = gr.File(interactive=False,
+                              label="Download File w/Sources [click get sources to make file]")
     with gr.Column(scale=2):
         pass
 sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
 
 )
 # FIXME: https://github.com/h2oai/h2ogpt/issues/106
 if os.getenv('TESTINGFAIL'):
+    max_beams = 8 if not (memory_restriction_level or is_public) else 1
 else:
     max_beams = 1
 num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
                       value=min(max_beams, kwargs['num_beams']), label="Beams",
                       info="Number of searches for optimal overall probability. "
                            "Uses more GPU memory/compute")
+# FIXME: 2048 should be tokenizer.model_max_length, but may not even have model yet
+if kwargs['max_new_tokens']:
+    max_max_new_tokens = kwargs['max_new_tokens']
+elif memory_restriction_level == 1:
+    max_max_new_tokens = 768
+elif memory_restriction_level == 2:
+    max_max_new_tokens = 512
+elif memory_restriction_level >= 3:
+    max_max_new_tokens = 256
+else:
+    max_max_new_tokens = 2048
 max_new_tokens = gr.Slider(
     minimum=1, maximum=max_max_new_tokens, step=1,
     value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
 
                               visible=not is_public)
 chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
                               visible=not is_public)
+count_chat_tokens_btn = gr.Button(value="Count Chat Tokens", visible=not is_public)
+chat_token_count = gr.Textbox(label="Chat Token Count", value=None,
+                              visible=not is_public, interactive=False)
+top_k_docs = gr.Slider(minimum=0, maximum=20, step=1,
+                       value=kwargs['top_k_docs'],
+                       label="Number of document chunks",
+                       info="For LangChain",
+                       visible=not is_public)
 
 with gr.TabItem("Models"):
+    load_msg = "Load-Unload Model/LORA [unload works if did not use --base_model]" if not is_public \
         else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
+    load_msg2 = "Load-Unload Model/LORA 2 [unload works if did not use --base_model]" if not is_public \
         else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
     compare_checkbox = gr.components.Checkbox(label="Compare Mode",
                                               value=False, visible=not is_public)
 
 lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
                           value=kwargs['lora_weights'], visible=kwargs['show_lora'])
 with gr.Column(scale=1):
+    load_model_button = gr.Button(load_msg).style(full_width=False, size='sm')
     model_load8bit_checkbox = gr.components.Checkbox(
         label="Load 8-bit [requires support]",
         value=kwargs['load_8bit'])
         label="Choose Devices [If not Checked, use all GPUs]",
         value=kwargs['infer_devices'])
     model_gpu = gr.Dropdown(n_gpus_list,
+                            label="GPU ID [-1 = all GPUs, if Choose is enabled]",
                             value=kwargs['gpu_id'])
     model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
                             interactive=False)
     lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
                            visible=kwargs['show_lora'], interactive=False)
 col_model2 = gr.Column(visible=False)
 with col_model2:
     with gr.Row():
 
         value=no_lora_str,
         visible=kwargs['show_lora'])
 with gr.Column(scale=1):
+    load_model_button2 = gr.Button(load_msg2).style(full_width=False, size='sm')
     model_load8bit_checkbox2 = gr.components.Checkbox(
         label="Load 8-bit 2 [requires support]",
         value=kwargs['load_8bit'])
         value=kwargs[
             'infer_devices'])
     model_gpu2 = gr.Dropdown(n_gpus_list,
+                             label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
                              value=kwargs['gpu_id'])
     # no model/lora loaded ever in model2 by default
     model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
     lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
                             visible=kwargs['show_lora'])
+with gr.Row():
+    with gr.Column(scale=50):
+        new_model = gr.Textbox(label="New Model HF name/path")
+    with gr.Row():
+        add_model_button = gr.Button("Add new model name").style(full_width=False, size='sm')
+    with gr.Column(scale=50):
+        new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
+    with gr.Row():
+        add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora']).style(
+            full_width=False, size='sm')
 with gr.TabItem("System"):
     admin_row = gr.Row()
     with admin_row:
 
 with gr.Row():
     zip_btn = gr.Button("Zip")
     zip_text = gr.Textbox(label="Zip file name", interactive=False)
+    file_output = gr.File(interactive=False, label="Zip file to Download")
 with gr.Row():
     s3up_btn = gr.Button("S3UP")
     s3up_text = gr.Textbox(label='S3UP result', interactive=False)
 
 description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
 if 'h2ogpt-research' in kwargs['base_model']:
     description += """<i><li>Research demonstration only, not used for commercial purposes.</i></li>"""
+description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md">Terms of Service</a></i></li></ul></p>"""
 gr.Markdown(value=description, show_label=False, interactive=False)
 
 # Get flagged data
 
                     api_name='add_txt_to_my' if allow_api else None) \
     .then(clear_textbox, outputs=user_text_text, queue=queue)
 
+get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0)
 
 # if change collection source, must clear doc selections from it to avoid inconsistency
 def clear_doc_choice():
+    return gr.Dropdown.update(choices=docs_state0, value=[docs_state0[0]])
 
 langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice)
 
 def update_dropdown(x):
+    return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
 
 get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=[file_source, docs_state],
                       queue=queue,
                       api_name='get_sources' if allow_api else None) \
     .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
 # show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
+show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
+show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
+                       api_name='show_sources' if allow_api else None)
+
+# Get inputs to evaluate() and make_db()
+# don't deepcopy, can contain model itself
+all_kwargs = kwargs.copy()
+all_kwargs.update(locals())
+
+refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
+                                     **get_kwargs(update_and_get_source_files_given_langchain_mode,
+                                                  exclude_names=['db1', 'langchain_mode'],
+                                                  **all_kwargs))
+refresh_sources_btn.click(fn=refresh_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
+                          api_name='refresh_sources' if allow_api else None)
 
 def check_admin_pass(x):
     return gr.update(visible=x == admin_pass)
 
 admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
     .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
 
 inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
 from functools import partial
 kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
 
 """ Similar to user() """
 args_list = list(args)
 
+if memory_restriction_level > 0:
+    max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
+else:
+    max_length_tokenize = 2048 - 256
 cutoff_len = max_length_tokenize * 4  # restrict deberta related to max for LLM
 smodel = score_model_state0[0]
 stokenizer = score_model_state0[1]
 
 # e.g. when user just hits enter in textbox,
 # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
 user_message1 = '\n'
+# ensure good visually, else markdown ignores multiple \n
+user_message1 = user_message1.replace('\n', '<br>')
 
 history = args_list[-1]
 if undo and history:
 
 # FIXME: compare, same history for now
 return history + [[user_message1, None]]
 
+def history_to_context(history, langchain_mode1, prompt_type1, chat1):
+    # ensure output will be unique to models
+    # FIXME: hard-coded 2048 implicitly passed:
+    _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level, for_context=True)
+    history = copy.deepcopy(history)
+
+    context1 = ''
+    if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
+        context1 = ''
+        # - 1 below because current instruction already in history from user()
+        for histi in range(0, len(history) - 1):
+            data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
+            prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
+                                                                                 chat1, reduced=True)
+            # md -> back to text, maybe not super important if model trained enough
+            if not kwargs['keep_sources_in_context']:
+                from gpt_langchain import source_prefix, source_postfix
+                import re
+                prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
+                                flags=re.DOTALL)
+                if prompt.endswith('\n<p>'):
+                    prompt = prompt[:-4]
+            prompt = prompt.replace('<br>', chat_sep)
+            if not prompt.endswith(chat_sep):
+                prompt += chat_sep
+            # most recent first, add older if can
+            # only include desired chat history
+            if len(prompt + context1) > max_prompt_length:
+                break
+            context1 = prompt + context1
+
+        _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
+                                                                        reduced=True)
+        if context1 and not context1.endswith(chat_sep):
+            context1 += chat_sep  # ensure if terminates abruptly, then human continues on next line
+    return context1
+
 def bot(*args, retry=False):
     """
     bot that consumes history for user input
 
 history = []
 yield history, ''
 return
 instruction1 = history[-1][0]
 if not instruction1:
     # reject empty query, can sometimes go nuts
     history = []
     yield history, ''
     return
+prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
+chat1 = args_list[eval_func_param_names.index('chat')]
+context1 = history_to_context(history, langchain_mode1, prompt_type1, chat1)
 args_list[0] = instruction1  # override original instruction with history from user
 args_list[2] = context1
 fun1 = partial(evaluate,
                my_db_state1,
                **kwargs_evaluate)
 try:
+    for output_fun in fun1(*tuple(args_list)):
+        output = output_fun['response']
+        extra = output_fun['sources']  # FIXME: can show sources in separate text box etc.
+        # ensure good visually, else markdown ignores multiple \n
+        bot_message = output.replace('\n', '<br>')
     history[-1][1] = bot_message
     yield history, ''
 except StopIteration:
 
 if len(stepy) != 2:
     # something off
     return False
+questionx = stepx[0].replace('<p>', '').replace('</p>', '') if stepx[0] is not None else None
+answerx = stepx[1].replace('<p>', '').replace('</p>', '') if stepx[1] is not None else None
 
+questiony = stepy[0].replace('<p>', '').replace('</p>', '') if stepy[0] is not None else None
+answery = stepy[1].replace('<p>', '').replace('</p>', '') if stepy[1] is not None else None
 
 if questionx != questiony or answerx != answery:
     return False
 
 lora_weights = ''
 
 all_kwargs1['lora_weights'] = lora_weights.strip()
+model1, tokenizer1, device1 = get_model(reward_type=False,
+                                        **get_kwargs(get_model, exclude_names=['reward_type'],
+                                                     **all_kwargs1))
 clear_torch_cache()
 
 if kwargs['debug']:
 
 chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
 nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
 if not is_public:
+    load_model_event = load_model_button.click(**load_model_args, api_name='load_model' if allow_api else None) \
         .then(**prompt_update_args) \
         .then(**chatbot_update_args) \
         .then(**nochat_update_args) \
 
 prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
 chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
 if not is_public:
+    load_model_event2 = load_model_button2.click(**load_model_args2,
+                                                 api_name='load_model2' if allow_api else None) \
         .then(**prompt_update_args2) \
         .then(**chatbot_update_args2) \
         .then(clear_torch_cache)
 
                submit_event3d, submit_event3f,
                submit_event_nochat],
               queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
+
+def count_chat_tokens(model_state1, chat1, prompt_type1):
+    if model_state1 and not isinstance(model_state1[1], str):
+        tokenizer = model_state1[1]
+    elif model_state0 and not isinstance(model_state0[1], str):
+        tokenizer = model_state0[1]
+    else:
+        tokenizer = None
+    if tokenizer is not None:
+        langchain_mode1 = 'ChatLLM'
+        # fake user message to mimic bot()
+        chat1 = copy.deepcopy(chat1)
+        chat1 = chat1 + [['user_message1', None]]
+        context1 = history_to_context(chat1, langchain_mode1, prompt_type1, chat1)
+        return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
+    else:
+        return "N/A"
+
+count_chat_tokens_btn.click(fn=count_chat_tokens, inputs=[model_state, text_output, prompt_type],
+                            outputs=chat_token_count, api_name='count_tokens' if allow_api else None)
+
 demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
 
 demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
 
 scheduler = BackgroundScheduler()
 scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
 if is_public and \
+        kwargs['base_model'] not in non_hf_types:
     # FIXME: disable for gptj, langchain or gpt4all modify print itself
     # FIXME: and any multi-threaded/async print will enter model output!
     scheduler.add_job(func=ping, trigger="interval", seconds=60)
 
 # import control
 if kwargs['langchain_mode'] == 'Disabled' and \
         os.environ.get("TEST_LANGCHAIN_IMPORT") and \
+        kwargs['base_model'] not in non_hf_types:
     assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
     assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
 
 demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
             favicon_path=favicon_path, prevent_thread_lock=True,
             auth=kwargs['auth'])
+if kwargs['verbose']:
+    print("Started GUI", flush=True)
 if kwargs['block_gradio_exit']:
     demo.block_thread()
 
 return inputs_list
 
 
+def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
     if langchain_mode in ['ChatLLM', 'LLM']:
         source_files_added = "NA"
         source_list = []
 
     sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
     with open(sources_file, "wt") as f:
         f.write(source_files_added)
+    source_list = docs_state0 + source_list
     return sources_file, source_list
 
 if langchain_mode == 'MyData':
     if db1[0] is not None:
         # then add
+        db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type)
     else:
         assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
         # then create
 
                 hf_embedding_model=hf_embedding_model)
     if db1[0] is None:
         db1[1] = None
+    source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
     return db1, x, y, source_files_added
 else:
     persist_directory = 'db_dir_%s' % langchain_mode
     if langchain_mode in dbs and dbs[langchain_mode] is not None:
         # then add
+        db, num_new_sources, new_sources_metadata = add_to_db(dbs[langchain_mode], sources, db_type=db_type)
     else:
         # then create
         db = get_db(sources, use_openai_embedding=use_openai_embedding,
 
     # NOTE we do not return db, because function call always same code path
     # return dbs[langchain_mode], x, y
     # db in this code path is updated in place
+    source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions)
     return x, y, source_files_added
 
 
+def get_db(db1, langchain_mode, dbs=None):
     with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
         if langchain_mode in ['wiki_full']:
             # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
 
             db = dbs[langchain_mode]
         else:
             db = None
+    return db
+
+
+def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
+    db = get_db(db1, langchain_mode, dbs=dbs)
+    return get_source_files(db=db, exceptions=None)
 
 
+def get_source_files(db=None, exceptions=None, metadatas=None):
     if exceptions is None:
         exceptions = []
 
+    # only should be one source, not confused
+    assert db is not None or metadatas is not None
+
+    if metadatas is None:
+        source_label = "Sources:"
+        if db is not None:
+            metadatas = db.get()['metadatas']
+        else:
+            metadatas = []
+        adding_new = False
     else:
+        source_label = "New Sources:"
+        adding_new = True
 
     # below automatically de-dups
     from gpt_langchain import get_url
 
 <html>
 <body>
 <p>
+{0} <br>
 </p>
 <div style="overflow-y: auto;height:400px">
 {1}
+{2}
 </div>
 </body>
 </html>
+""".format(source_label, source_files_added, exceptions_html)
 elif metadatas:
     source_files_added = """\
 <html>
 <body>
 <p>
+{0} <br>
 </p>
 <div style="overflow-y: auto;height:400px">
+{1}
 </div>
 </body>
 </html>
+""".format(source_label, source_files_added)
 elif exceptions_html:
     source_files_added = """\
 <html>
 
 </html>
 """.format(exceptions_html)
 else:
+    if adding_new:
+        source_files_added = "No New Sources"
+    else:
+        source_files_added = "No Sources"
 
 return source_files_added
+
+
+def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=None, first_para=None,
+                                                     text_limit=None, chunk=None, chunk_size=None,
+                                                     user_path=None, db_type=None, load_db_if_exists=None,
+                                                     n_jobs=None, verbose=None):
+    db = get_db(db1, langchain_mode, dbs=dbs)
+
+    from gpt_langchain import make_db
+    db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
+                                                        hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
+                                                        first_para=first_para, text_limit=text_limit, chunk=chunk,
+                                                        chunk_size=chunk_size,
1732 |
+
chunk_size=chunk_size,
|
1733 |
+
langchain_mode=langchain_mode,
|
1734 |
+
user_path=user_path,
|
1735 |
+
db_type=db_type,
|
1736 |
+
load_db_if_exists=load_db_if_exists,
|
1737 |
+
db=db,
|
1738 |
+
n_jobs=n_jobs,
|
1739 |
+
verbose=verbose)
|
1740 |
+
# return only new sources with text saying such
|
1741 |
+
return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
|
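As a rough illustration of how the source-listing helpers above fit together (the dbs dict, user_path, and db_type values below are assumed placeholders, not taken from this commit):

from gradio_runner import update_and_get_source_files_given_langchain_mode

dbs = {}            # langchain_mode -> vector db, normally built at server startup
db1 = [None, None]  # per-user MyData state: [db, collection id]

# Re-scan an assumed ./user_docs folder and get back the "New Sources:" HTML block
html_report = update_and_get_source_files_given_langchain_mode(
    db1, 'UserData', dbs=dbs,
    first_para=False, text_limit=None, chunk=True, chunk_size=512,
    user_path='./user_docs', db_type='chroma',
    load_db_if_exists=True, n_jobs=-1, verbose=True)
print(html_report)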
gradio_themes.py
CHANGED
@@ -1,7 +1,10 @@
 from __future__ import annotations
+
+from typing import Iterable
+
 from gradio.themes.soft import Soft
 from gradio.themes import Color
-from gradio.themes.utils import colors, sizes
+from gradio.themes.utils import colors, sizes, fonts
 
 h2o_yellow = Color(
     name="yellow",
@@ -43,6 +46,22 @@ class H2oTheme(Soft):
         spacing_size: sizes.Size | str = sizes.spacing_md,
         radius_size: sizes.Size | str = sizes.radius_md,
         text_size: sizes.Size | str = sizes.text_lg,
+        font: fonts.Font
+        | str
+        | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("Montserrat"),
+            "ui-sans-serif",
+            "system-ui",
+            "sans-serif",
+        ),
+        font_mono: fonts.Font
+        | str
+        | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("IBM Plex Mono"),
+            "ui-monospace",
+            "Consolas",
+            "monospace",
+        ),
     ):
         super().__init__(
             primary_hue=primary_hue,
@@ -51,6 +70,8 @@ class H2oTheme(Soft):
             spacing_size=spacing_size,
             radius_size=radius_size,
             text_size=text_size,
+            font=font,
+            font_mono=font_mono,
         )
         super().set(
             link_text_color="#3344DD",
@@ -89,6 +110,22 @@ class SoftTheme(Soft):
         spacing_size: sizes.Size | str = sizes.spacing_md,
         radius_size: sizes.Size | str = sizes.radius_md,
         text_size: sizes.Size | str = sizes.text_md,
+        font: fonts.Font
+        | str
+        | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("Montserrat"),
+            "ui-sans-serif",
+            "system-ui",
+            "sans-serif",
+        ),
+        font_mono: fonts.Font
+        | str
+        | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("IBM Plex Mono"),
+            "ui-monospace",
+            "Consolas",
+            "monospace",
+        ),
     ):
         super().__init__(
             primary_hue=primary_hue,
@@ -97,6 +134,8 @@ class SoftTheme(Soft):
             spacing_size=spacing_size,
             radius_size=radius_size,
             text_size=text_size,
+            font=font,
+            font_mono=font_mono,
         )
 
 
@@ -125,7 +164,7 @@ def get_h2o_title(title):
     <h1 style="line-height:60px">{title}</h1>
     </div>
     <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
-        <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/h2o-qr.png></img>
+        <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png></img>
     </div>
     """
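For context, these themes are consumed when the Blocks UI is built. A minimal, hedged sketch of a throwaway app picking up the new Montserrat / IBM Plex Mono defaults (the real UI is assembled elsewhere in the repo):

import gradio as gr
from gradio_themes import H2oTheme

with gr.Blocks(theme=H2oTheme()) as demo:  # font/font_mono defaults now come from the theme
    gr.Markdown("h2oGPT theme preview")
demo.launch()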
h2oai_pipeline.py
CHANGED
@@ -2,36 +2,57 @@ from transformers import TextGenerationPipeline
 from transformers.pipelines.text_generation import ReturnType
 
 from stopping import get_stopping
-
-prompt_type = "human_bot"
-human = "<human>:"
-bot = "<bot>:"
-
-# human-bot interaction like OIG dataset
-prompt = """{human} {instruction}
-{bot}""".format(
-    human=human,
-    instruction="{instruction}",
-    bot=bot,
-)
+from prompter import Prompter
 
 
 class H2OTextGenerationPipeline(TextGenerationPipeline):
-    def __init__(self, *args,
-                 sanitize_bot_response=True,
+    def __init__(self, *args, debug=False, chat=False, stream_output=False,
+                 sanitize_bot_response=True,
+                 use_prompter=True, prompter=None, prompt_type=None,
+                 max_input_tokens=2048 - 256, **kwargs):
+        """
+        HF-like pipeline, but handle instruction prompting and stopping (for some models)
+        :param args:
+        :param debug:
+        :param chat:
+        :param stream_output:
+        :param sanitize_bot_response:
+        :param use_prompter: Whether to use prompter.  If pass prompt_type, will make prompter
+        :param prompter: prompter, can pass if have already
+        :param prompt_type: prompt_type, e.g. human_bot.  See prompt_type to model mapping in from prompter.py.
+                            If use_prompter, then will make prompter and use it.
+        :param max_input_tokens:
+        :param kwargs:
+        """
         super().__init__(*args, **kwargs)
-        self.use_prompter = use_prompter
         self.prompt_text = None
+        self.use_prompter = use_prompter
+        self.prompt_type = prompt_type
+        self.prompter = prompter
         if self.use_prompter:
+            if self.prompter is not None:
+                assert self.prompter.prompt_type is not None
+            else:
+                self.prompter = Prompter(self.prompt_type, debug=debug, chat=chat, stream_output=stream_output)
+            self.human = self.prompter.humanstr
+            self.bot = self.prompter.botstr
+            self.can_stop = True
         else:
             self.prompter = None
+            self.human = None
+            self.bot = None
+            self.can_stop = False
         self.sanitize_bot_response = sanitize_bot_response
+        self.max_input_tokens = max_input_tokens  # not for generate, so ok that not kwargs
 
     def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
+        data_point = dict(context='', instruction=prompt_text, input='')
+        if self.prompter is not None:
+            prompt_text = self.prompter.generate_prompt(data_point)
         self.prompt_text = prompt_text
+        if handle_long_generation is None:
+            # forces truncation of inputs to avoid critical failure
+            handle_long_generation = 'hole'
         return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                   **generate_kwargs)
 
@@ -43,12 +64,65 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
             outputs = rec['generated_text']
             outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
                                                  sanitize_bot_response=self.sanitize_bot_response)
+        elif self.bot and self.human:
+            outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
         else:
-            outputs = rec['generated_text']
+            outputs = rec['generated_text']
         rec['generated_text'] = outputs
         return records
 
     def _forward(self, model_inputs, **generate_kwargs):
+        if self.can_stop:
+            stopping_criteria = get_stopping(self.prompt_type, self.tokenizer, self.device, human=self.human,
+                                             bot=self.bot)
+            generate_kwargs['stopping_criteria'] = stopping_criteria
+        # return super()._forward(model_inputs, **generate_kwargs)
+        return self.__forward(model_inputs, **generate_kwargs)
+
+    # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
+    # FIXME: https://github.com/h2oai/h2ogpt/issues/172
+    def __forward(self, model_inputs, **generate_kwargs):
+        input_ids = model_inputs["input_ids"]
+        attention_mask = model_inputs.get("attention_mask", None)
+        # Allow empty prompts
+        if input_ids.shape[1] == 0:
+            input_ids = None
+            attention_mask = None
+            in_b = 1
+        else:
+            in_b = input_ids.shape[0]
+        prompt_text = model_inputs.pop("prompt_text")
+
+        ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
+        ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
+        # generate_kwargs = copy.deepcopy(generate_kwargs)
+        prefix_length = generate_kwargs.pop("prefix_length", 0)
+        if prefix_length > 0:
+            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
+                    "generation_config" in generate_kwargs
+                    and generate_kwargs["generation_config"].max_new_tokens is not None
+            )
+            if not has_max_new_tokens:
+                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
+                generate_kwargs["max_length"] += prefix_length
+            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
+                    "generation_config" in generate_kwargs
+                    and generate_kwargs["generation_config"].min_new_tokens is not None
+            )
+            if not has_min_new_tokens and "min_length" in generate_kwargs:
+                generate_kwargs["min_length"] += prefix_length
+
+        # BS x SL
+        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
+        out_b = generated_sequence.shape[0]
+        if self.framework == "pt":
+            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
+        elif self.framework == "tf":
+            from transformers import is_tf_available
+            if is_tf_available():
+                import tensorflow as tf
+                generated_sequence = tf.reshape(generated_sequence,
+                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
+            else:
+                raise ValueError("TF not avaialble.")
+        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
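A hedged sketch of driving this pipeline directly; the model id is only an example and the generation settings are arbitrary (inside the repo the pipeline is constructed by the server code, not by hand like this):

from transformers import AutoModelForCausalLM, AutoTokenizer
from h2oai_pipeline import H2OTextGenerationPipeline

model_id = "h2oai/h2ogpt-oig-oasst1-512-6_9b"  # assumed example model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# prompt_type makes the pipeline build a Prompter and wrap inputs in human/bot tags
pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer, prompt_type='human_bot')
print(pipe("Why is the sky blue?", max_new_tokens=128)[0]['generated_text'])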
prompter.py
CHANGED
@@ -1,6 +1,8 @@
 import time
 from enum import Enum
 
+non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
+
 
 class PromptType(Enum):
     plain = 0
@@ -17,6 +19,10 @@ class PromptType(Enum):
     open_assistant = 11
     wizard_lm = 12
     wizard_mega = 13
+    instruct_vicuna2 = 14
+    instruct_vicuna3 = 15
+    wizard2 = 16
+    wizard3 = 17
 
 
 prompt_type_to_model_name = {
@@ -26,6 +32,7 @@ prompt_type_to_model_name = {
         'EleutherAI/pythia-12b',
         'EleutherAI/pythia-12b-deduped',
         'EleutherAI/gpt-neox-20b',
+        'openlm-research/open_llama_7b_700bt_preview',
         'decapoda-research/llama-7b-hf',
         'decapoda-research/llama-13b-hf',
         'decapoda-research/llama-30b-hf',
@@ -39,7 +46,8 @@ prompt_type_to_model_name = {
         'mosaicml/mpt-7b-instruct',  # internal code handles instruct
         'mosaicml/mpt-7b-chat',  # NC, internal code handles instruct
         'gptj',  # internally handles prompting
-        'llama',  #
+        'llama',  # plain, or need to choose prompt_type for given TheBloke model
+        'gpt4all_llama',  # internally handles prompting
     ],
     'prompt_answer': [
         'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
@@ -47,6 +55,7 @@ prompt_type_to_model_name = {
         'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
         'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
         'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
     ],
     'instruct': [],
     'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -58,7 +67,9 @@ prompt_type_to_model_name = {
         'h2oai/h2ogpt-oig-oasst1-512-6_9b',
         'h2oai/h2ogpt-oig-oasst1-256-6.9b',  # legacy
         'h2oai/h2ogpt-oig-oasst1-512-6.9b',  # legacy
-        'h2oai/h2ogpt-research-oasst1-512-30b',
+        'h2oai/h2ogpt-research-oasst1-512-30b',
+        'h2oai/h2ogpt-oasst1-falcon-40b',
+        'h2oai/h2ogpt-oig-oasst1-falcon-40b',
     ],
     'dai_faq': [],
     'summarize': [],
@@ -83,7 +94,8 @@ for p in PromptType:
 
 
 def get_prompt(prompt_type, chat, context, reduced):
-    if prompt_type in [
+    if prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
+                       PromptType.plain.name]:
         promptA = promptB = PreInstruct = PreInput = PreResponse = ''
         terminate_response = []
         chat_sep = ''
@@ -95,11 +107,14 @@ def get_prompt(prompt_type, chat, context, reduced):
         chat_sep = '\n'
         humanstr = ''
         botstr = ''
-    elif prompt_type in [
+    elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
+                         PromptType.instruct.name] + [PromptType.instruct_with_end.value,
+                                                      str(PromptType.instruct_with_end.value),
+                                                      PromptType.instruct_with_end.name]:
         promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
+                chat and reduced) else ''
         promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
+                chat and reduced) else ''
 
         PreInstruct = """
 ### Instruction:
@@ -112,18 +127,20 @@ def get_prompt(prompt_type, chat, context, reduced):
         PreResponse = """
 ### Response:
 """
-        if prompt_type in [
+        if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
+                           PromptType.instruct_with_end.name]:
             terminate_response = ['### End']
         else:
             terminate_response = None
         chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
-    elif prompt_type in [
+    elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
+                         PromptType.quality.name]:
         promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
+                chat and reduced) else ''
         promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
+                chat and reduced) else ''
 
         PreInstruct = """
 ### Instruction:
@@ -140,10 +157,14 @@ def get_prompt(prompt_type, chat, context, reduced):
         chat_sep = '\n'
         humanstr = PreInstruct  # first thing human says
         botstr = PreResponse  # first thing bot says
-    elif prompt_type in [
+    elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
+                         PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
+                                                       str(PromptType.human_bot_orig.value),
+                                                       PromptType.human_bot_orig.name]:
         human = '<human>:'
         bot = "<bot>:"
-        if reduced or context or prompt_type in [
+        if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
+                                                 PromptType.human_bot.name]:
             preprompt = ''
         else:
             cur_date = time.strftime('%Y-%m-%d')
@@ -174,7 +195,8 @@ Current Time: {}
         chat_sep = '\n'
         humanstr = human  # tag before human talks
         botstr = bot  # tag before bot talks
-    elif prompt_type in [
+    elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
+                         PromptType.dai_faq.name]:
         promptA = ''
         promptB = 'Answer the following Driverless AI question.\n'
 
@@ -191,7 +213,8 @@ Current Time: {}
         chat_sep = terminate_response
         humanstr = PreInstruct
         botstr = PreResponse
-    elif prompt_type in [
+    elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
+                         PromptType.summarize.name]:
         promptA = promptB = PreInput = ''
         PreInstruct = '## Main Text\n\n'
         PreResponse = '\n\n## Summary\n\n'
@@ -199,10 +222,11 @@ Current Time: {}
         chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
-    elif prompt_type in [
+    elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
+                         PromptType.instruct_vicuna.name]:
         promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
                             "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
+                chat and reduced) else ''
 
         PreInstruct = """
 ### Human:
@@ -218,7 +242,8 @@ Current Time: {}
         chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
-    elif prompt_type in [
+    elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
+                         PromptType.prompt_answer.name]:
         preprompt = ''
         prompt_tokens = "<|prompt|>"
         answer_tokens = "<|answer|>"
@@ -232,7 +257,8 @@ Current Time: {}
         chat_sep = eos
         humanstr = prompt_tokens
         botstr = answer_tokens
-    elif prompt_type in [
+    elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
+                         PromptType.open_assistant.name]:
         # From added_tokens.json
         preprompt = ''
         prompt_tokens = "<|prompter|>"
@@ -248,20 +274,22 @@ Current Time: {}
         chat_sep = eos
         humanstr = prompt_tokens
         botstr = answer_tokens
-    elif prompt_type in [
+    elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
+                         PromptType.wizard_lm.name]:
         # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
         preprompt = ''
         start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
         PreInstruct = ""
         PreInput = None
-        PreResponse = "\n\n### Response"
+        PreResponse = "\n\n### Response\n"
        eos = "</s>"
         terminate_response = [PreResponse, eos]
         chat_sep = eos
         humanstr = promptA
         botstr = PreResponse
-    elif prompt_type in [
+    elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
+                         PromptType.wizard_mega.name]:
         preprompt = ''
         start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
@@ -276,6 +304,75 @@ Current Time: {}
         chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
+    elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
+                         PromptType.instruct_vicuna2.name]:
+        promptA = promptB = "" if not (
+                chat and reduced) else ''
+
+        PreInstruct = """
+HUMAN:
+"""
+
+        PreInput = None
+
+        PreResponse = """
+ASSISTANT:
+"""
+        terminate_response = [
+            'HUMAN:']  # but only allow terminate after prompt is found correctly, else can't terminate
+        chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
+                         PromptType.instruct_vicuna3.name]:
+        promptA = promptB = "" if not (
+                chat and reduced) else ''
+
+        PreInstruct = """
+### User:
+"""
+
+        PreInput = None
+
+        PreResponse = """
+### Assistant:
+"""
+        terminate_response = [
+            '### User:']  # but only allow terminate after prompt is found correctly, else can't terminate
+        chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
+                         PromptType.wizard2.name]:
+        # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
+        preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = """
+### Instruction:
+"""
+        PreInput = None
+        PreResponse = """
+### Response:
+"""
+        terminate_response = [PreResponse]
+        chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
+                         PromptType.wizard3.name]:
+        # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
+        preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = """USER: """
+        PreInput = None
+        PreResponse = """ASSISTANT: """
+        terminate_response = [PreResponse]
+        chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+
     else:
         raise RuntimeError("No such prompt_type=%s" % prompt_type)
@@ -412,7 +509,7 @@ class Prompter(object):
         multi_output = len(outputs) > 1
 
         for oi, output in enumerate(outputs):
-            if self.prompt_type in [
+            if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
                 output = clean_response(output)
             elif prompt is None:
                 # then use most basic parsing like pipeline
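To make the newly added prompt formats concrete, a small hedged sketch of rendering a prompt for one of them (the instruction text is arbitrary):

from prompter import Prompter

prompter = Prompter('wizard3', debug=False, chat=False, stream_output=False)
data_point = dict(context='', instruction='Why is the sky blue?', input='')
prompt = prompter.generate_prompt(data_point)
# Roughly: the wizard3 preprompt, then "USER: Why is the sky blue?", then "ASSISTANT: "
print(prompt)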
requirements.txt
CHANGED
@@ -1,7 +1,6 @@
 # for generate (gradio server) and finetune
 datasets==2.12.0
 sentencepiece==0.1.97
-accelerate==0.18.0
 gradio==3.31.0
 huggingface_hub==0.14.1
 appdirs==1.4.4
@@ -18,8 +17,9 @@ numpy==1.24.2
 pandas==2.0.0
 matplotlib==3.7.1
 loralib==0.1.1
-bitsandbytes==0.
+bitsandbytes==0.39.0
+accelerate==0.19.0
+git+https://github.com/huggingface/peft.git@3714aa2fff158fdfa637b2b65952580801d890b2
 transformers==4.28.1
 tokenizers==0.13.3
 APScheduler==3.10.1
@@ -50,18 +50,15 @@ pypandoc_binary==1.11
 openpyxl==3.1.2
 lm_dataformat==0.0.20
 bioc==2.0
-# To install with constraints
-# grep -v '#\|peft' requirements.txt > req_constraints.txt ; pip install -r requirements_optional_langchain.txt -c req_constraints.txt
 
+# falcon
+einops==0.6.1
 # optional for chat with PDF
-langchain==0.0.
+langchain==0.0.183
 pypdf==3.8.1
 tiktoken==0.3.3
 # avoid textract, requires old six
 #textract==1.6.5
-# choose:
-#faiss-cpu
-faiss-gpu==1.7.2
 
 # for HF embeddings
 sentence_transformers==2.2.2
@@ -69,7 +66,7 @@ sentence_transformers==2.2.2
 openai==0.27.6
 
 # local vector db
-chromadb==0.3.
+chromadb==0.3.25
 # server vector db
 #pymilvus==2.2.8
 
@@ -92,8 +89,12 @@ requests_file==1.5.1
 tabulate==0.9.0
 # FYI pandoc already part of requirements.txt
 
+# JSONLoader, but makes some trouble for some users
+# jq==1.4.1
 
 # to check licenses
 # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
 pip-licenses==4.3.0
+
+# weaviate vector db
+weaviate-client==3.19.2
stopping.py
CHANGED
@@ -1,6 +1,8 @@
 import torch
 from transformers import StoppingCriteria, StoppingCriteriaList
 
+from prompter import PromptType
+
 
 class StoppingCriteriaSub(StoppingCriteria):
 
@@ -24,14 +26,14 @@ class StoppingCriteriaSub(StoppingCriteria):
 
 
 def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
-    if prompt_type in [
-        if prompt_type ==
+    if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
+        if prompt_type == PromptType.human_bot.name:
             # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
             # stopping only starts once output is beyond prompt
             # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
             stop_words = [human, bot, '\n' + human, '\n' + bot]
             encounters = [1, 2]
-        elif prompt_type ==
+        elif prompt_type == PromptType.instruct_vicuna.name:
             # even below is not enough, generic strings and many ways to encode
             stop_words = [
                 '### Human:',
@@ -58,7 +60,7 @@ def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:")
         stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
         stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
         # avoid padding in front of tokens
-        if tokenizer.
+        if tokenizer._pad_token:  # use hidden variable to avoid annoying properly logger bug
             stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
         # handle fake \n added
         stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
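A hedged sketch of attaching these stopping criteria to a plain generate() call; the model choice and the prompt text are illustrative assumptions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from prompter import PromptType
from stopping import get_stopping

model_id = "h2oai/h2ogpt-oig-oasst1-512-6_9b"  # assumed example model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

stopping_criteria = get_stopping(PromptType.human_bot.name, tokenizer, model.device)
inputs = tokenizer("<human>: Why is the sky blue?\n<bot>:", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.decode(out[0], skip_special_tokens=True))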
utils.py
CHANGED
@@ -1,6 +1,7 @@
 import contextlib
 import functools
 import hashlib
+import inspect
 import os
 import gc
 import pathlib
@@ -16,6 +17,8 @@ from datetime import datetime
 import filelock
 import requests, uuid
 from typing import Tuple, Callable, Dict
+from tqdm.auto import tqdm
+from joblib import Parallel
 from concurrent.futures import ProcessPoolExecutor
 import numpy as np
 import pandas as pd
@@ -371,18 +374,15 @@ def sanitize_filename(name):
     return name
 
 
-def
-    path = args[0]
-    assert not os.path.samefile(path, "./tmp"), "Should not be trying to remove entire data directory: %s" % str(path)
-    # print("Removing path %s" % args[0])  # for debugging
+def shutil_rmtree(*args, **kwargs):
     return shutil.rmtree(*args, **kwargs)
 
 
-def
+def remove(path: str):
     try:
         if path is not None and os.path.exists(path):
             if os.path.isdir(path):
+                shutil_rmtree(path, ignore_errors=True)
             else:
                 with contextlib.suppress(FileNotFoundError):
                     os.remove(path)
@@ -408,7 +408,7 @@ def atomic_move_simple(src, dst):
         shutil.move(src, dst)
     except (shutil.Error, FileExistsError):
         pass
+    remove(src)
 
 
 def download_simple(url, dest=None, print_func=None):
@@ -481,7 +481,7 @@ def download(url, dest=None, dest_path=None):
         shutil.move(dest_tmp, dest)
     except FileExistsError:
         pass
+    remove(dest_tmp)
     return dest
 
 
@@ -766,3 +766,78 @@ def call_subprocess_onetask(func, args=None, kwargs=None):
     with ProcessPoolExecutor(max_workers=1) as executor:
         future = executor.submit(_traced_func, *args, **kwargs)
         return future.result()
+
+
+class ProgressParallel(Parallel):
+    def __init__(self, use_tqdm=True, total=None, *args, **kwargs):
+        self._use_tqdm = use_tqdm
+        self._total = total
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
+            return Parallel.__call__(self, *args, **kwargs)
+
+    def print_progress(self):
+        if self._total is None:
+            self._pbar.total = self.n_dispatched_tasks
+        self._pbar.n = self.n_completed_tasks
+        self._pbar.refresh()
+
+
+def get_kwargs(func, exclude_names=None, **kwargs):
+    func_names = list(inspect.signature(func).parameters)
+    missing_kwargs = [x for x in func_names if x not in kwargs]
+    if exclude_names:
+        for k in exclude_names:
+            if k in missing_kwargs:
+                missing_kwargs.remove(k)
+            if k in func_names:
+                func_names.remove(k)
+    assert not missing_kwargs, "Missing %s" % missing_kwargs
+    kwargs = {k: v for k, v in kwargs.items() if k in func_names}
+    return kwargs
+
+
+import pkg_resources
+have_faiss = False
+
+try:
+    assert pkg_resources.get_distribution('faiss') is not None
+    have_faiss = True
+except (pkg_resources.DistributionNotFound, AssertionError):
+    pass
+try:
+    assert pkg_resources.get_distribution('faiss_gpu') is not None
+    have_faiss = True
+except (pkg_resources.DistributionNotFound, AssertionError):
+    pass
+try:
+    assert pkg_resources.get_distribution('faiss_cpu') is not None
+    have_faiss = True
+except (pkg_resources.DistributionNotFound, AssertionError):
+    pass
+
+
+def hash_file(file):
+    try:
+        import hashlib
+
+        # BUF_SIZE is totally arbitrary, change for your app!
+        BUF_SIZE = 65536  # lets read stuff in 64kb chunks!
+
+        md5 = hashlib.md5()
+        # sha1 = hashlib.sha1()
+
+        with open(file, 'rb') as f:
+            while True:
+                data = f.read(BUF_SIZE)
+                if not data:
+                    break
+                md5.update(data)
+                # sha1.update(data)
+    except BaseException as e:
+        print("Cannot hash %s due to %s" % (file, str(e)))
+        traceback.print_exc()
+        md5 = None
+    return md5.hexdigest()
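For orientation, a small hedged sketch exercising the two new general-purpose helpers; the workload and the target function are made up for illustration:

from joblib import delayed
from utils import ProgressParallel, get_kwargs


def square(x):
    return x * x


# joblib's bookkeeping is surfaced through tqdm via print_progress()
squares = ProgressParallel(n_jobs=2, total=10)(delayed(square)(i) for i in range(10))


def target(alpha, beta=1):
    return alpha + beta


# keep only kwargs that target() accepts; 'beta' is excluded from the completeness check
filtered = get_kwargs(target, exclude_names=['beta'], alpha=3, gamma='ignored')
print(squares, target(**filtered))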