Spaces:
Running
Running
pseudotensor
commited on
Commit
•
cf9ad1a
1
Parent(s):
1c0f538
Update with h2oGPT hash e7d4914948ac2b9a5a82f1cc82556197b261cb46
Browse files- app.py +1 -1
- client_test.py +22 -14
- enums.py +16 -1
- evaluate_params.py +47 -0
- gen.py +0 -0
- gpt4all_llm.py +9 -0
- gpt_langchain.py +156 -52
- gradio_runner.py +143 -41
- gradio_utils/__pycache__/grclient.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/prompt_form.cpython-310.pyc +0 -0
- gradio_utils/prompt_form.py +7 -3
- h2oai_pipeline.py +1 -0
- prompter.py +4 -2
app.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
|
|
|
1 |
+
gen.py
|
client_test.py
CHANGED
@@ -12,13 +12,13 @@ Currently, this will force model to be on a single GPU.
|
|
12 |
|
13 |
Then run this client as:
|
14 |
|
15 |
-
python client_test.py
|
16 |
|
17 |
|
18 |
|
19 |
For HF spaces:
|
20 |
|
21 |
-
HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
|
22 |
|
23 |
Result:
|
24 |
|
@@ -28,7 +28,7 @@ Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
|
|
28 |
|
29 |
For demo:
|
30 |
|
31 |
-
HOST="https://gpt.h2o.ai" python client_test.py
|
32 |
|
33 |
Result:
|
34 |
|
@@ -48,7 +48,7 @@ import markdown # pip install markdown
|
|
48 |
import pytest
|
49 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
|
51 |
-
from enums import DocumentChoices
|
52 |
|
53 |
debug = False
|
54 |
|
@@ -67,7 +67,9 @@ def get_client(serialize=True):
|
|
67 |
def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
68 |
max_new_tokens=50,
|
69 |
top_k_docs=3,
|
70 |
-
langchain_mode='Disabled'
|
|
|
|
|
71 |
from collections import OrderedDict
|
72 |
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
73 |
iinput='', # only for chat=True
|
@@ -76,7 +78,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
76 |
# but leave stream_output=False for simple input/output mode
|
77 |
stream_output=stream_output,
|
78 |
prompt_type=prompt_type,
|
79 |
-
prompt_dict=
|
80 |
temperature=0.1,
|
81 |
top_p=0.75,
|
82 |
top_k=40,
|
@@ -92,12 +94,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
92 |
instruction_nochat=prompt if not chat else '',
|
93 |
iinput_nochat='', # only for chat=False
|
94 |
langchain_mode=langchain_mode,
|
|
|
95 |
top_k_docs=top_k_docs,
|
96 |
chunk=True,
|
97 |
chunk_size=512,
|
98 |
document_choice=[DocumentChoices.All_Relevant.name],
|
99 |
)
|
100 |
-
from
|
101 |
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
102 |
if chat:
|
103 |
# add chatbot output on end. Assumes serialize=False
|
@@ -198,6 +201,7 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
|
|
198 |
instruction_nochat=prompt,
|
199 |
iinput_nochat='',
|
200 |
langchain_mode='Disabled',
|
|
|
201 |
top_k_docs=4,
|
202 |
document_choice=['All'],
|
203 |
)
|
@@ -219,21 +223,24 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
|
|
219 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
220 |
def test_client_chat(prompt_type='human_bot'):
|
221 |
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
222 |
-
langchain_mode='Disabled')
|
223 |
|
224 |
|
225 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
226 |
def test_client_chat_stream(prompt_type='human_bot'):
|
227 |
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
228 |
stream_output=True, max_new_tokens=512,
|
229 |
-
langchain_mode='Disabled')
|
230 |
|
231 |
|
232 |
-
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode
|
|
|
233 |
client = get_client(serialize=False)
|
234 |
|
235 |
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
236 |
-
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode
|
|
|
|
|
237 |
return run_client(client, prompt, args, kwargs)
|
238 |
|
239 |
|
@@ -276,14 +283,15 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
|
276 |
def test_client_nochat_stream(prompt_type='human_bot'):
|
277 |
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
278 |
stream_output=True, max_new_tokens=512,
|
279 |
-
langchain_mode='Disabled')
|
280 |
|
281 |
|
282 |
-
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
|
283 |
client = get_client(serialize=False)
|
284 |
|
285 |
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
286 |
-
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode
|
|
|
287 |
return run_client_gen(client, prompt, args, kwargs)
|
288 |
|
289 |
|
|
|
12 |
|
13 |
Then run this client as:
|
14 |
|
15 |
+
python src/client_test.py
|
16 |
|
17 |
|
18 |
|
19 |
For HF spaces:
|
20 |
|
21 |
+
HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py
|
22 |
|
23 |
Result:
|
24 |
|
|
|
28 |
|
29 |
For demo:
|
30 |
|
31 |
+
HOST="https://gpt.h2o.ai" python src/client_test.py
|
32 |
|
33 |
Result:
|
34 |
|
|
|
48 |
import pytest
|
49 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
|
51 |
+
from enums import DocumentChoices, LangChainAction
|
52 |
|
53 |
debug = False
|
54 |
|
|
|
67 |
def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
68 |
max_new_tokens=50,
|
69 |
top_k_docs=3,
|
70 |
+
langchain_mode='Disabled',
|
71 |
+
langchain_action=LangChainAction.QUERY.value,
|
72 |
+
prompt_dict=None):
|
73 |
from collections import OrderedDict
|
74 |
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
75 |
iinput='', # only for chat=True
|
|
|
78 |
# but leave stream_output=False for simple input/output mode
|
79 |
stream_output=stream_output,
|
80 |
prompt_type=prompt_type,
|
81 |
+
prompt_dict=prompt_dict,
|
82 |
temperature=0.1,
|
83 |
top_p=0.75,
|
84 |
top_k=40,
|
|
|
94 |
instruction_nochat=prompt if not chat else '',
|
95 |
iinput_nochat='', # only for chat=False
|
96 |
langchain_mode=langchain_mode,
|
97 |
+
langchain_action=langchain_action,
|
98 |
top_k_docs=top_k_docs,
|
99 |
chunk=True,
|
100 |
chunk_size=512,
|
101 |
document_choice=[DocumentChoices.All_Relevant.name],
|
102 |
)
|
103 |
+
from evaluate_params import eval_func_param_names
|
104 |
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
105 |
if chat:
|
106 |
# add chatbot output on end. Assumes serialize=False
|
|
|
201 |
instruction_nochat=prompt,
|
202 |
iinput_nochat='',
|
203 |
langchain_mode='Disabled',
|
204 |
+
langchain_action=LangChainAction.QUERY.value,
|
205 |
top_k_docs=4,
|
206 |
document_choice=['All'],
|
207 |
)
|
|
|
223 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
224 |
def test_client_chat(prompt_type='human_bot'):
|
225 |
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
226 |
+
langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
|
227 |
|
228 |
|
229 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
230 |
def test_client_chat_stream(prompt_type='human_bot'):
|
231 |
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
232 |
stream_output=True, max_new_tokens=512,
|
233 |
+
langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
|
234 |
|
235 |
|
236 |
+
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action,
|
237 |
+
prompt_dict=None):
|
238 |
client = get_client(serialize=False)
|
239 |
|
240 |
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
|
241 |
+
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
|
242 |
+
langchain_action=langchain_action,
|
243 |
+
prompt_dict=prompt_dict)
|
244 |
return run_client(client, prompt, args, kwargs)
|
245 |
|
246 |
|
|
|
283 |
def test_client_nochat_stream(prompt_type='human_bot'):
|
284 |
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
285 |
stream_output=True, max_new_tokens=512,
|
286 |
+
langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
|
287 |
|
288 |
|
289 |
+
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action):
|
290 |
client = get_client(serialize=False)
|
291 |
|
292 |
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
293 |
+
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
|
294 |
+
langchain_action=langchain_action)
|
295 |
return run_client_gen(client, prompt, args, kwargs)
|
296 |
|
297 |
|
enums.py
CHANGED
@@ -37,6 +37,9 @@ class DocumentChoices(Enum):
|
|
37 |
Just_LLM = 3
|
38 |
|
39 |
|
|
|
|
|
|
|
40 |
class LangChainMode(Enum):
|
41 |
"""LangChain mode"""
|
42 |
|
@@ -52,10 +55,22 @@ class LangChainMode(Enum):
|
|
52 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
53 |
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
56 |
|
57 |
|
58 |
-
# from site-packages/langchain/llms/openai.py
|
|
|
59 |
model_token_mapping = {
|
60 |
"gpt-4": 8192,
|
61 |
"gpt-4-0314": 8192,
|
|
|
37 |
Just_LLM = 3
|
38 |
|
39 |
|
40 |
+
non_query_commands = [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]
|
41 |
+
|
42 |
+
|
43 |
class LangChainMode(Enum):
|
44 |
"""LangChain mode"""
|
45 |
|
|
|
55 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
56 |
|
57 |
|
58 |
+
class LangChainAction(Enum):
|
59 |
+
"""LangChain action"""
|
60 |
+
|
61 |
+
QUERY = "Query"
|
62 |
+
# WIP:
|
63 |
+
#SUMMARIZE_MAP = "Summarize_map_reduce"
|
64 |
+
SUMMARIZE_MAP = "Summarize"
|
65 |
+
SUMMARIZE_ALL = "Summarize_all"
|
66 |
+
SUMMARIZE_REFINE = "Summarize_refine"
|
67 |
+
|
68 |
+
|
69 |
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
70 |
|
71 |
|
72 |
+
# from site-packages/langchain/llms/openai.py
|
73 |
+
# but needed since ChatOpenAI doesn't have this information
|
74 |
model_token_mapping = {
|
75 |
"gpt-4": 8192,
|
76 |
"gpt-4-0314": 8192,
|
evaluate_params.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
no_default_param_names = [
|
2 |
+
'instruction',
|
3 |
+
'iinput',
|
4 |
+
'context',
|
5 |
+
'instruction_nochat',
|
6 |
+
'iinput_nochat',
|
7 |
+
]
|
8 |
+
|
9 |
+
gen_hyper = ['temperature',
|
10 |
+
'top_p',
|
11 |
+
'top_k',
|
12 |
+
'num_beams',
|
13 |
+
'max_new_tokens',
|
14 |
+
'min_new_tokens',
|
15 |
+
'early_stopping',
|
16 |
+
'max_time',
|
17 |
+
'repetition_penalty',
|
18 |
+
'num_return_sequences',
|
19 |
+
'do_sample',
|
20 |
+
]
|
21 |
+
|
22 |
+
eval_func_param_names = ['instruction',
|
23 |
+
'iinput',
|
24 |
+
'context',
|
25 |
+
'stream_output',
|
26 |
+
'prompt_type',
|
27 |
+
'prompt_dict'] + \
|
28 |
+
gen_hyper + \
|
29 |
+
['chat',
|
30 |
+
'instruction_nochat',
|
31 |
+
'iinput_nochat',
|
32 |
+
'langchain_mode',
|
33 |
+
'langchain_action',
|
34 |
+
'top_k_docs',
|
35 |
+
'chunk',
|
36 |
+
'chunk_size',
|
37 |
+
'document_choice',
|
38 |
+
]
|
39 |
+
|
40 |
+
# form evaluate defaults for submit_nochat_api
|
41 |
+
eval_func_param_names_defaults = eval_func_param_names.copy()
|
42 |
+
for k in no_default_param_names:
|
43 |
+
if k in eval_func_param_names_defaults:
|
44 |
+
eval_func_param_names_defaults.remove(k)
|
45 |
+
|
46 |
+
|
47 |
+
eval_extra_columns = ['prompt', 'response', 'score']
|
gen.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
gpt4all_llm.py
CHANGED
@@ -19,6 +19,15 @@ def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
|
19 |
n_ctx=2048 - 256)
|
20 |
env_gpt4all_file = ".env_gpt4all"
|
21 |
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
if base_model == "llama":
|
24 |
if 'model_path_llama' not in model_kwargs:
|
|
|
19 |
n_ctx=2048 - 256)
|
20 |
env_gpt4all_file = ".env_gpt4all"
|
21 |
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
22 |
+
# make int or float if can to satisfy types for class
|
23 |
+
for k, v in model_kwargs.items():
|
24 |
+
try:
|
25 |
+
if float(v) == int(v):
|
26 |
+
model_kwargs[k] = int(v)
|
27 |
+
else:
|
28 |
+
model_kwargs[k] = float(v)
|
29 |
+
except:
|
30 |
+
pass
|
31 |
|
32 |
if base_model == "llama":
|
33 |
if 'model_path_llama' not in model_kwargs:
|
gpt_langchain.py
CHANGED
@@ -23,8 +23,10 @@ from langchain.callbacks import streaming_stdout
|
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
-
from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
|
27 |
-
|
|
|
|
|
28 |
from prompter import non_hf_types, PromptType, Prompter
|
29 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
30 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
|
@@ -43,7 +45,8 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
|
43 |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
44 |
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
45 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
46 |
-
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
|
|
|
47 |
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
|
48 |
from langchain.chains.question_answering import load_qa_chain
|
49 |
from langchain.docstore.document import Document
|
@@ -351,6 +354,7 @@ class GradioInference(LLM):
|
|
351 |
stream_output = self.stream
|
352 |
gr_client = self.client
|
353 |
client_langchain_mode = 'Disabled'
|
|
|
354 |
top_k_docs = 1
|
355 |
chunk = True
|
356 |
chunk_size = 512
|
@@ -379,6 +383,7 @@ class GradioInference(LLM):
|
|
379 |
instruction_nochat=prompt if not self.chat_client else '',
|
380 |
iinput_nochat='', # only for chat=False
|
381 |
langchain_mode=client_langchain_mode,
|
|
|
382 |
top_k_docs=top_k_docs,
|
383 |
chunk=chunk,
|
384 |
chunk_size=chunk_size,
|
@@ -637,6 +642,7 @@ def get_llm(use_openai_model=False,
|
|
637 |
callbacks = [StreamingGradioCallbackHandler()]
|
638 |
assert prompter is not None
|
639 |
stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
|
|
|
640 |
|
641 |
if gr_client:
|
642 |
chat_client = False
|
@@ -744,7 +750,7 @@ def get_llm(use_openai_model=False,
|
|
744 |
|
745 |
if stream_output:
|
746 |
skip_prompt = False
|
747 |
-
from
|
748 |
decoder_kwargs = {}
|
749 |
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
|
750 |
gen_kwargs.update(dict(streamer=streamer))
|
@@ -944,14 +950,16 @@ have_playwright = False
|
|
944 |
|
945 |
image_types = ["png", "jpg", "jpeg"]
|
946 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
947 |
-
"md",
|
|
|
948 |
"enex", "eml", "epub", "odt", "pptx", "ppt",
|
949 |
"zip", "urls",
|
|
|
950 |
]
|
951 |
# "msg", GPL3
|
952 |
|
953 |
if have_libreoffice:
|
954 |
-
non_image_types.extend(["docx", "doc"])
|
955 |
|
956 |
file_types = non_image_types + image_types
|
957 |
|
@@ -961,7 +969,7 @@ def add_meta(docs1, file):
|
|
961 |
hashid = hash_file(file)
|
962 |
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
963 |
docs1 = [docs1]
|
964 |
-
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
|
965 |
|
966 |
|
967 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
@@ -1038,6 +1046,10 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1038 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1039 |
add_meta(docs1, file)
|
1040 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
|
|
|
|
|
|
|
|
1041 |
elif file.lower().endswith('.odt'):
|
1042 |
docs1 = UnstructuredODTLoader(file_path=file).load()
|
1043 |
add_meta(docs1, file)
|
@@ -1171,7 +1183,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1171 |
# so just extract in path where
|
1172 |
zip_ref.extractall(base_path)
|
1173 |
# recurse
|
1174 |
-
doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception)
|
1175 |
else:
|
1176 |
raise RuntimeError("No file handler for %s" % os.path.basename(file))
|
1177 |
|
@@ -1758,6 +1770,8 @@ def run_qa_db(**kwargs):
|
|
1758 |
|
1759 |
|
1760 |
def _run_qa_db(query=None,
|
|
|
|
|
1761 |
use_openai_model=False, use_openai_embedding=False,
|
1762 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1763 |
user_path=None,
|
@@ -1787,6 +1801,7 @@ def _run_qa_db(query=None,
|
|
1787 |
repetition_penalty=1.0,
|
1788 |
num_return_sequences=1,
|
1789 |
langchain_mode=None,
|
|
|
1790 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1791 |
n_jobs=-1,
|
1792 |
verbose=False,
|
@@ -1803,7 +1818,7 @@ def _run_qa_db(query=None,
|
|
1803 |
:param use_openai_embedding:
|
1804 |
:param first_para:
|
1805 |
:param text_limit:
|
1806 |
-
:param
|
1807 |
:param chunk:
|
1808 |
:param chunk_size:
|
1809 |
:param user_path: user path to glob recursively from
|
@@ -1869,12 +1884,28 @@ def _run_qa_db(query=None,
|
|
1869 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1870 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1871 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1872 |
-
docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
|
1873 |
-
if cmd in
|
1874 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1875 |
yield formatted_doc_chunks, ''
|
1876 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1877 |
if chain is None and model_name not in non_hf_types:
|
|
|
1878 |
# can only return if HF type
|
1879 |
return
|
1880 |
|
@@ -1933,6 +1964,7 @@ def _run_qa_db(query=None,
|
|
1933 |
|
1934 |
|
1935 |
def get_similarity_chain(query=None,
|
|
|
1936 |
use_openai_model=False, use_openai_embedding=False,
|
1937 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1938 |
user_path=None,
|
@@ -1947,6 +1979,7 @@ def get_similarity_chain(query=None,
|
|
1947 |
load_db_if_exists=False,
|
1948 |
db=None,
|
1949 |
langchain_mode=None,
|
|
|
1950 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1951 |
n_jobs=-1,
|
1952 |
# beyond run_db_query:
|
@@ -1997,25 +2030,56 @@ def get_similarity_chain(query=None,
|
|
1997 |
db=db,
|
1998 |
n_jobs=n_jobs,
|
1999 |
verbose=verbose)
|
2000 |
-
|
2001 |
-
if
|
2002 |
-
|
2003 |
-
|
2004 |
-
|
2005 |
-
|
2006 |
-
|
2007 |
-
|
2008 |
-
|
2009 |
-
|
2010 |
-
|
2011 |
-
|
2012 |
-
|
2013 |
-
|
2014 |
-
|
2015 |
-
{context}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2016 |
\"\"\"
|
2017 |
-
%s
|
2018 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2019 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2020 |
use_template = True
|
2021 |
else:
|
@@ -2040,14 +2104,26 @@ def get_similarity_chain(query=None,
|
|
2040 |
if cmd == DocumentChoices.Just_LLM.name:
|
2041 |
docs = []
|
2042 |
scores = []
|
2043 |
-
elif cmd == DocumentChoices.Only_All_Sources.name:
|
2044 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2045 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2046 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
2047 |
-
for result in zip(db_documents, db_metadatas)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2048 |
docs = [x[0] for x in docs_with_score]
|
2049 |
scores = [x[1] for x in docs_with_score]
|
|
|
2050 |
else:
|
|
|
|
|
2051 |
if top_k_docs == -1 or auto_reduce_chunks:
|
2052 |
# docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2053 |
top_k_docs_tokenize = 100
|
@@ -2120,6 +2196,7 @@ def get_similarity_chain(query=None,
|
|
2120 |
if reverse_docs:
|
2121 |
docs_with_score.reverse()
|
2122 |
# cut off so no high distance docs/sources considered
|
|
|
2123 |
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
2124 |
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
2125 |
if len(scores) > 0 and verbose:
|
@@ -2131,14 +2208,14 @@ def get_similarity_chain(query=None,
|
|
2131 |
|
2132 |
if not docs and use_context and model_name not in non_hf_types:
|
2133 |
# if HF type and have no docs, can bail out
|
2134 |
-
return docs, None, [], False
|
2135 |
|
2136 |
-
if cmd in
|
2137 |
# no LLM use
|
2138 |
-
return docs, None, [], False
|
2139 |
|
2140 |
common_words_file = "data/NGSL_1.2_stats.csv.zip"
|
2141 |
-
if os.path.isfile(common_words_file):
|
2142 |
df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
|
2143 |
import string
|
2144 |
reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
|
@@ -2155,25 +2232,47 @@ def get_similarity_chain(query=None,
|
|
2155 |
use_context = False
|
2156 |
template = template_if_no_docs
|
2157 |
|
2158 |
-
if
|
2159 |
-
|
2160 |
-
|
2161 |
-
|
2162 |
-
|
2163 |
-
|
2164 |
-
|
2165 |
-
|
2166 |
-
|
2167 |
-
|
2168 |
-
|
2169 |
-
|
2170 |
-
|
2171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2172 |
else:
|
2173 |
-
|
2174 |
|
2175 |
-
target
|
2176 |
-
return docs, target, scores, use_context
|
2177 |
|
2178 |
|
2179 |
def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
|
@@ -2243,6 +2342,11 @@ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
|
2243 |
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
|
2244 |
separators=separators)
|
2245 |
source_chunks = splitter.split_documents(sources)
|
|
|
|
|
|
|
|
|
|
|
2246 |
return source_chunks
|
2247 |
|
2248 |
|
|
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
+
from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
|
27 |
+
LangChainAction, LangChainMode
|
28 |
+
from evaluate_params import gen_hyper
|
29 |
+
from gen import get_model, SEED
|
30 |
from prompter import non_hf_types, PromptType, Prompter
|
31 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
32 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
|
|
|
45 |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
46 |
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
47 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
48 |
+
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \
|
49 |
+
UnstructuredExcelLoader
|
50 |
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
|
51 |
from langchain.chains.question_answering import load_qa_chain
|
52 |
from langchain.docstore.document import Document
|
|
|
354 |
stream_output = self.stream
|
355 |
gr_client = self.client
|
356 |
client_langchain_mode = 'Disabled'
|
357 |
+
client_langchain_action = LangChainAction.QUERY.value
|
358 |
top_k_docs = 1
|
359 |
chunk = True
|
360 |
chunk_size = 512
|
|
|
383 |
instruction_nochat=prompt if not self.chat_client else '',
|
384 |
iinput_nochat='', # only for chat=False
|
385 |
langchain_mode=client_langchain_mode,
|
386 |
+
langchain_action=client_langchain_action,
|
387 |
top_k_docs=top_k_docs,
|
388 |
chunk=chunk,
|
389 |
chunk_size=chunk_size,
|
|
|
642 |
callbacks = [StreamingGradioCallbackHandler()]
|
643 |
assert prompter is not None
|
644 |
stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
|
645 |
+
stop_sequences = [x for x in stop_sequences if x]
|
646 |
|
647 |
if gr_client:
|
648 |
chat_client = False
|
|
|
750 |
|
751 |
if stream_output:
|
752 |
skip_prompt = False
|
753 |
+
from gen import H2OTextIteratorStreamer
|
754 |
decoder_kwargs = {}
|
755 |
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
|
756 |
gen_kwargs.update(dict(streamer=streamer))
|
|
|
950 |
|
951 |
image_types = ["png", "jpg", "jpeg"]
|
952 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
953 |
+
"md",
|
954 |
+
"html", "mhtml",
|
955 |
"enex", "eml", "epub", "odt", "pptx", "ppt",
|
956 |
"zip", "urls",
|
957 |
+
|
958 |
]
|
959 |
# "msg", GPL3
|
960 |
|
961 |
if have_libreoffice:
|
962 |
+
non_image_types.extend(["docx", "doc", "xls", "xlsx"])
|
963 |
|
964 |
file_types = non_image_types + image_types
|
965 |
|
|
|
969 |
hashid = hash_file(file)
|
970 |
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
971 |
docs1 = [docs1]
|
972 |
+
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
|
973 |
|
974 |
|
975 |
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
|
1046 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1047 |
add_meta(docs1, file)
|
1048 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1049 |
+
elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and have_libreoffice:
|
1050 |
+
docs1 = UnstructuredExcelLoader(file_path=file).load()
|
1051 |
+
add_meta(docs1, file)
|
1052 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1053 |
elif file.lower().endswith('.odt'):
|
1054 |
docs1 = UnstructuredODTLoader(file_path=file).load()
|
1055 |
add_meta(docs1, file)
|
|
|
1183 |
# so just extract in path where
|
1184 |
zip_ref.extractall(base_path)
|
1185 |
# recurse
|
1186 |
+
doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception, n_jobs=n_jobs)
|
1187 |
else:
|
1188 |
raise RuntimeError("No file handler for %s" % os.path.basename(file))
|
1189 |
|
|
|
1770 |
|
1771 |
|
1772 |
def _run_qa_db(query=None,
|
1773 |
+
iinput=None,
|
1774 |
+
context=None,
|
1775 |
use_openai_model=False, use_openai_embedding=False,
|
1776 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1777 |
user_path=None,
|
|
|
1801 |
repetition_penalty=1.0,
|
1802 |
num_return_sequences=1,
|
1803 |
langchain_mode=None,
|
1804 |
+
langchain_action=None,
|
1805 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1806 |
n_jobs=-1,
|
1807 |
verbose=False,
|
|
|
1818 |
:param use_openai_embedding:
|
1819 |
:param first_para:
|
1820 |
:param text_limit:
|
1821 |
+
:param top_k_docs:
|
1822 |
:param chunk:
|
1823 |
:param chunk_size:
|
1824 |
:param user_path: user path to glob recursively from
|
|
|
1884 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1885 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1886 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1887 |
+
docs, chain, scores, use_context, have_any_docs = get_similarity_chain(**sim_kwargs)
|
1888 |
+
if cmd in non_query_commands:
|
1889 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1890 |
yield formatted_doc_chunks, ''
|
1891 |
return
|
1892 |
+
if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
1893 |
+
LangChainAction.SUMMARIZE_ALL.value,
|
1894 |
+
LangChainAction.SUMMARIZE_REFINE.value]:
|
1895 |
+
ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
|
1896 |
+
extra = ''
|
1897 |
+
yield ret, extra
|
1898 |
+
return
|
1899 |
+
if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
|
1900 |
+
LangChainMode.CHAT_LLM.value,
|
1901 |
+
LangChainMode.LLM.value]:
|
1902 |
+
ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
|
1903 |
+
extra = ''
|
1904 |
+
yield ret, extra
|
1905 |
+
return
|
1906 |
+
|
1907 |
if chain is None and model_name not in non_hf_types:
|
1908 |
+
# here if no docs at all and not HF type
|
1909 |
# can only return if HF type
|
1910 |
return
|
1911 |
|
|
|
1964 |
|
1965 |
|
1966 |
def get_similarity_chain(query=None,
|
1967 |
+
iinput=None,
|
1968 |
use_openai_model=False, use_openai_embedding=False,
|
1969 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1970 |
user_path=None,
|
|
|
1979 |
load_db_if_exists=False,
|
1980 |
db=None,
|
1981 |
langchain_mode=None,
|
1982 |
+
langchain_action=None,
|
1983 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1984 |
n_jobs=-1,
|
1985 |
# beyond run_db_query:
|
|
|
2030 |
db=db,
|
2031 |
n_jobs=n_jobs,
|
2032 |
verbose=verbose)
|
2033 |
+
have_any_docs = db is not None
|
2034 |
+
if langchain_action == LangChainAction.QUERY.value:
|
2035 |
+
if iinput:
|
2036 |
+
query = "%s\n%s" % (query, iinput)
|
2037 |
+
|
2038 |
+
if 'falcon' in model_name:
|
2039 |
+
extra = "According to only the information in the document sources provided within the context above, "
|
2040 |
+
prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
|
2041 |
+
elif inference_server in ['openai', 'openai_chat']:
|
2042 |
+
extra = "According to (primarily) the information in the document sources provided within context above, "
|
2043 |
+
prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
|
2044 |
+
else:
|
2045 |
+
extra = ""
|
2046 |
+
prefix = ""
|
2047 |
+
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
|
2048 |
+
template_if_no_docs = template = """%s{context}{question}""" % prefix
|
2049 |
+
else:
|
2050 |
+
template = """%s
|
2051 |
+
\"\"\"
|
2052 |
+
{context}
|
2053 |
+
\"\"\"
|
2054 |
+
%s{question}""" % (prefix, extra)
|
2055 |
+
template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
|
2056 |
+
elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]:
|
2057 |
+
none = ['', '\n', None]
|
2058 |
+
if query in none and iinput in none:
|
2059 |
+
prompt_summary = "Using only the text above, write a condensed and concise summary:\n"
|
2060 |
+
elif query not in none:
|
2061 |
+
prompt_summary = "Focusing on %s, write a condensed and concise Summary:\n" % query
|
2062 |
+
elif iinput not in None:
|
2063 |
+
prompt_summary = iinput
|
2064 |
+
else:
|
2065 |
+
prompt_summary = "Focusing on %s, %s:\n" % (query, iinput)
|
2066 |
+
# don't auto reduce
|
2067 |
+
auto_reduce_chunks = False
|
2068 |
+
if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
|
2069 |
+
fstring = '{text}'
|
2070 |
+
else:
|
2071 |
+
fstring = '{input_documents}'
|
2072 |
+
template = """In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text:
|
2073 |
\"\"\"
|
2074 |
+
%s
|
2075 |
+
\"\"\"\n%s""" % (fstring, prompt_summary)
|
2076 |
+
template_if_no_docs = "Exactly only say: There are no documents to summarize."
|
2077 |
+
elif langchain_action in [LangChainAction.SUMMARIZE_REFINE]:
|
2078 |
+
template = '' # unused
|
2079 |
+
template_if_no_docs = '' # unused
|
2080 |
+
else:
|
2081 |
+
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2082 |
+
|
2083 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2084 |
use_template = True
|
2085 |
else:
|
|
|
2104 |
if cmd == DocumentChoices.Just_LLM.name:
|
2105 |
docs = []
|
2106 |
scores = []
|
2107 |
+
elif cmd == DocumentChoices.Only_All_Sources.name or query in [None, '', '\n']:
|
2108 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2109 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2110 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
2111 |
+
for result in zip(db_documents, db_metadatas)]
|
2112 |
+
|
2113 |
+
# order documents
|
2114 |
+
doc_hashes = [x['doc_hash'] for x in db_metadatas]
|
2115 |
+
doc_chunk_ids = [x['chunk_id'] for x in db_metadatas]
|
2116 |
+
docs_with_score = [x for _, _, x in
|
2117 |
+
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
2118 |
+
]
|
2119 |
+
|
2120 |
+
docs_with_score = docs_with_score[:top_k_docs]
|
2121 |
docs = [x[0] for x in docs_with_score]
|
2122 |
scores = [x[1] for x in docs_with_score]
|
2123 |
+
have_any_docs |= len(docs) > 0
|
2124 |
else:
|
2125 |
+
# FIXME: if langchain_action == LangChainAction.SUMMARIZE_MAP.value
|
2126 |
+
# if map_reduce, then no need to auto reduce chunks
|
2127 |
if top_k_docs == -1 or auto_reduce_chunks:
|
2128 |
# docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2129 |
top_k_docs_tokenize = 100
|
|
|
2196 |
if reverse_docs:
|
2197 |
docs_with_score.reverse()
|
2198 |
# cut off so no high distance docs/sources considered
|
2199 |
+
have_any_docs |= len(docs_with_score) > 0 # before cut
|
2200 |
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
2201 |
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
2202 |
if len(scores) > 0 and verbose:
|
|
|
2208 |
|
2209 |
if not docs and use_context and model_name not in non_hf_types:
|
2210 |
# if HF type and have no docs, can bail out
|
2211 |
+
return docs, None, [], False, have_any_docs
|
2212 |
|
2213 |
+
if cmd in non_query_commands:
|
2214 |
# no LLM use
|
2215 |
+
return docs, None, [], False, have_any_docs
|
2216 |
|
2217 |
common_words_file = "data/NGSL_1.2_stats.csv.zip"
|
2218 |
+
if os.path.isfile(common_words_file) and langchain_mode == LangChainAction.QUERY.value:
|
2219 |
df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
|
2220 |
import string
|
2221 |
reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
|
|
|
2232 |
use_context = False
|
2233 |
template = template_if_no_docs
|
2234 |
|
2235 |
+
if langchain_action == LangChainAction.QUERY.value:
|
2236 |
+
if use_template:
|
2237 |
+
# instruct-like, rather than few-shot prompt_type='plain' as default
|
2238 |
+
# but then sources confuse the model with how inserted among rest of text, so avoid
|
2239 |
+
prompt = PromptTemplate(
|
2240 |
+
# input_variables=["summaries", "question"],
|
2241 |
+
input_variables=["context", "question"],
|
2242 |
+
template=template,
|
2243 |
+
)
|
2244 |
+
chain = load_qa_chain(llm, prompt=prompt)
|
2245 |
+
else:
|
2246 |
+
# only if use_openai_model = True, unused normally except in testing
|
2247 |
+
chain = load_qa_with_sources_chain(llm)
|
2248 |
+
if not use_context:
|
2249 |
+
chain_kwargs = dict(input_documents=[], question=query)
|
2250 |
+
else:
|
2251 |
+
chain_kwargs = dict(input_documents=docs, question=query)
|
2252 |
+
target = wrapped_partial(chain, chain_kwargs)
|
2253 |
+
elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
2254 |
+
LangChainAction.SUMMARIZE_REFINE,
|
2255 |
+
LangChainAction.SUMMARIZE_ALL.value]:
|
2256 |
+
from langchain.chains.summarize import load_summarize_chain
|
2257 |
+
if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
|
2258 |
+
prompt = PromptTemplate(input_variables=["text"], template=template)
|
2259 |
+
chain = load_summarize_chain(llm, chain_type="map_reduce",
|
2260 |
+
map_prompt=prompt, combine_prompt=prompt, return_intermediate_steps=True)
|
2261 |
+
target = wrapped_partial(chain, {"input_documents": docs}) # , return_only_outputs=True)
|
2262 |
+
elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
|
2263 |
+
assert use_template
|
2264 |
+
prompt = PromptTemplate(input_variables=["text"], template=template)
|
2265 |
+
chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt, return_intermediate_steps=True)
|
2266 |
+
target = wrapped_partial(chain)
|
2267 |
+
elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
|
2268 |
+
chain = load_summarize_chain(llm, chain_type="refine", return_intermediate_steps=True)
|
2269 |
+
target = wrapped_partial(chain)
|
2270 |
+
else:
|
2271 |
+
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2272 |
else:
|
2273 |
+
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2274 |
|
2275 |
+
return docs, target, scores, use_context, have_any_docs
|
|
|
2276 |
|
2277 |
|
2278 |
def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
|
|
|
2342 |
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
|
2343 |
separators=separators)
|
2344 |
source_chunks = splitter.split_documents(sources)
|
2345 |
+
|
2346 |
+
# currently in order, but when pull from db won't be, so mark order and document by hash
|
2347 |
+
doc_hash = str(uuid.uuid4())[:10]
|
2348 |
+
[x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
|
2349 |
+
|
2350 |
return source_chunks
|
2351 |
|
2352 |
|
gradio_runner.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import copy
|
2 |
import functools
|
3 |
import inspect
|
@@ -49,16 +50,16 @@ def fix_pydantic_duplicate_validators_error():
|
|
49 |
|
50 |
fix_pydantic_duplicate_validators_error()
|
51 |
|
52 |
-
from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainMode
|
53 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
|
54 |
text_xsm
|
55 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
56 |
get_prompt
|
57 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
58 |
ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
|
59 |
-
from
|
60 |
-
|
61 |
-
|
62 |
|
63 |
from apscheduler.schedulers.background import BackgroundScheduler
|
64 |
|
@@ -99,6 +100,7 @@ def go_gradio(**kwargs):
|
|
99 |
dbs = kwargs['dbs']
|
100 |
db_type = kwargs['db_type']
|
101 |
visible_langchain_modes = kwargs['visible_langchain_modes']
|
|
|
102 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
103 |
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
|
104 |
enable_sources_list = kwargs['enable_sources_list']
|
@@ -213,7 +215,28 @@ def go_gradio(**kwargs):
|
|
213 |
'base_model') else no_model_msg
|
214 |
output_label0_model2 = no_model_msg
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
|
|
|
|
|
|
|
|
|
217 |
for k in no_default_param_names:
|
218 |
default_kwargs[k] = ''
|
219 |
|
@@ -239,7 +262,8 @@ def go_gradio(**kwargs):
|
|
239 |
model_options_state = gr.State([model_options])
|
240 |
lora_options_state = gr.State([lora_options])
|
241 |
server_options_state = gr.State([server_options])
|
242 |
-
|
|
|
243 |
chat_state = gr.State({})
|
244 |
# make user default first and default choice, dedup
|
245 |
docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
|
@@ -283,7 +307,7 @@ def go_gradio(**kwargs):
|
|
283 |
|
284 |
col_chat = gr.Column(visible=kwargs['chat'])
|
285 |
with col_chat:
|
286 |
-
instruction, submit, stop_btn = make_prompt_form(kwargs)
|
287 |
text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
|
288 |
**kwargs)
|
289 |
|
@@ -332,6 +356,12 @@ def go_gradio(**kwargs):
|
|
332 |
value=kwargs['langchain_mode'],
|
333 |
label="Data Collection of Sources",
|
334 |
visible=kwargs['langchain_mode'] != 'Disabled')
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
|
336 |
with data_row2:
|
337 |
with gr.Column(scale=50):
|
@@ -726,6 +756,7 @@ def go_gradio(**kwargs):
|
|
726 |
caption_loader=caption_loader,
|
727 |
verbose=kwargs['verbose'],
|
728 |
user_path=kwargs['user_path'],
|
|
|
729 |
)
|
730 |
add_file_outputs = [fileup_output, langchain_mode, add_to_shared_db_btn, add_to_my_db_btn]
|
731 |
add_file_kwargs = dict(fn=update_user_db_func,
|
@@ -804,6 +835,7 @@ def go_gradio(**kwargs):
|
|
804 |
caption_loader=caption_loader,
|
805 |
verbose=kwargs['verbose'],
|
806 |
user_path=kwargs['user_path'],
|
|
|
807 |
)
|
808 |
|
809 |
add_my_file_outputs = [fileup_output, langchain_mode, my_db_state, add_to_shared_db_btn, add_to_my_db_btn]
|
@@ -920,19 +952,59 @@ def go_gradio(**kwargs):
|
|
920 |
for k in inputs_kwargs_list:
|
921 |
assert k in kwargs_evaluate, "Missing %s" % k
|
922 |
|
923 |
-
def
|
924 |
-
|
925 |
-
|
926 |
-
|
927 |
-
|
928 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
929 |
|
930 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
931 |
**kwargs_evaluate)
|
932 |
-
fun2 = partial(
|
|
|
|
|
933 |
**kwargs_evaluate)
|
934 |
-
fun_with_dict_str = partial(
|
935 |
-
|
|
|
936 |
**kwargs_evaluate
|
937 |
)
|
938 |
|
@@ -1072,14 +1144,17 @@ def go_gradio(**kwargs):
|
|
1072 |
User that fills history for bot
|
1073 |
:param args:
|
1074 |
:param undo:
|
|
|
1075 |
:param sanitize_user_prompt:
|
1076 |
-
:param model2:
|
1077 |
:return:
|
1078 |
"""
|
1079 |
args_list = list(args)
|
1080 |
user_message = args_list[eval_func_param_names.index('instruction')] # chat only
|
1081 |
input1 = args_list[eval_func_param_names.index('iinput')] # chat only
|
1082 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
|
|
|
|
|
|
1083 |
if not prompt_type1:
|
1084 |
# shouldn't have to specify if CLI launched model
|
1085 |
prompt_type1 = kwargs['prompt_type']
|
@@ -1110,8 +1185,12 @@ def go_gradio(**kwargs):
|
|
1110 |
history[-1][1] = None
|
1111 |
return history
|
1112 |
if user_message1 in ['', None, '\n']:
|
1113 |
-
|
1114 |
-
|
|
|
|
|
|
|
|
|
1115 |
user_message1 = fix_text_for_gradio(user_message1)
|
1116 |
return history + [[user_message1, None]]
|
1117 |
|
@@ -1147,11 +1226,13 @@ def go_gradio(**kwargs):
|
|
1147 |
else:
|
1148 |
return 2000
|
1149 |
|
1150 |
-
def prep_bot(*args, retry=False):
|
1151 |
"""
|
1152 |
|
1153 |
:param args:
|
1154 |
:param retry:
|
|
|
|
|
1155 |
:return: last element is True if should run bot, False if should just yield history
|
1156 |
"""
|
1157 |
# don't deepcopy, can contain model itself
|
@@ -1159,12 +1240,16 @@ def go_gradio(**kwargs):
|
|
1159 |
model_state1 = args_list[-3]
|
1160 |
my_db_state1 = args_list[-2]
|
1161 |
history = args_list[-1]
|
1162 |
-
|
|
|
1163 |
|
1164 |
if model_state1['model'] is None or model_state1['model'] == no_model_str:
|
1165 |
return history, None, None, None
|
1166 |
|
1167 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
|
|
|
|
|
|
1168 |
if not history:
|
1169 |
print("No history", flush=True)
|
1170 |
history = []
|
@@ -1175,22 +1260,23 @@ def go_gradio(**kwargs):
|
|
1175 |
instruction1 = history[-1][0]
|
1176 |
history[-1][1] = None
|
1177 |
elif not instruction1:
|
1178 |
-
|
1179 |
-
|
|
|
|
|
|
|
|
|
1180 |
elif len(history) > 0 and history[-1][1] not in [None, '']:
|
1181 |
# reject submit button if already filled and not retrying
|
1182 |
# None when not filling with '' to keep client happy
|
1183 |
return history, None, None, None
|
1184 |
|
1185 |
# shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
|
1186 |
-
prompt_type1 =
|
1187 |
-
|
1188 |
-
|
1189 |
-
|
1190 |
-
|
1191 |
-
prompt_dict1 = kwargs.get('prompt_dict', args_list[eval_func_param_names.index('prompt_dict')])
|
1192 |
-
args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1 = \
|
1193 |
-
model_state1.get('prompt_dict', prompt_dict1)
|
1194 |
|
1195 |
chat1 = args_list[eval_func_param_names.index('chat')]
|
1196 |
model_max_length1 = get_model_max_length(model_state1)
|
@@ -1264,6 +1350,7 @@ def go_gradio(**kwargs):
|
|
1264 |
for res in get_response(fun1, history):
|
1265 |
yield res
|
1266 |
finally:
|
|
|
1267 |
clear_embeddings(langchain_mode1, my_db_state1)
|
1268 |
|
1269 |
def all_bot(*args, retry=False, model_states1=None):
|
@@ -1277,7 +1364,7 @@ def go_gradio(**kwargs):
|
|
1277 |
my_db_state1 = None # will be filled below by some bot
|
1278 |
try:
|
1279 |
gen_list = []
|
1280 |
-
for chatbot1, model_state1 in zip(chatbots, model_states1):
|
1281 |
args_list1 = args_list0.copy()
|
1282 |
args_list1.insert(-1, model_state1) # insert at -1 so is at -2
|
1283 |
# if at start, have None in response still, replace with '' so client etc. acts like normal
|
@@ -1289,7 +1376,8 @@ def go_gradio(**kwargs):
|
|
1289 |
# so consistent with prep_bot()
|
1290 |
# with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
|
1291 |
# langchain_mode1 and my_db_state1 should be same for every bot
|
1292 |
-
history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry
|
|
|
1293 |
gen1 = get_response(fun1, history)
|
1294 |
if stream_output1:
|
1295 |
gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
|
@@ -1301,6 +1389,7 @@ def go_gradio(**kwargs):
|
|
1301 |
tgen0 = time.time()
|
1302 |
for res1 in itertools.zip_longest(*gen_list):
|
1303 |
if time.time() - tgen0 > max_time1:
|
|
|
1304 |
break
|
1305 |
|
1306 |
bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
|
@@ -1735,6 +1824,9 @@ def go_gradio(**kwargs):
|
|
1735 |
|
1736 |
def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
|
1737 |
infer_devices, gpu_id):
|
|
|
|
|
|
|
1738 |
# ensure old model removed from GPU memory
|
1739 |
if kwargs['debug']:
|
1740 |
print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
|
@@ -2161,6 +2253,15 @@ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData',
|
|
2161 |
clear_torch_cache()
|
2162 |
|
2163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2164 |
def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
|
2165 |
user_path=None,
|
2166 |
use_openai_embedding=None,
|
@@ -2170,7 +2271,8 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2170 |
captions_model=None,
|
2171 |
enable_ocr=None,
|
2172 |
verbose=None,
|
2173 |
-
is_url=None, is_txt=None
|
|
|
2174 |
assert use_openai_embedding is not None
|
2175 |
assert hf_embedding_model is not None
|
2176 |
assert caption_loader is not None
|
@@ -2211,6 +2313,7 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2211 |
print("Adding %s" % file, flush=True)
|
2212 |
sources = path_to_docs(file if not is_url and not is_txt else None,
|
2213 |
verbose=verbose,
|
|
|
2214 |
chunk=chunk, chunk_size=chunk_size,
|
2215 |
url=file if is_url else None,
|
2216 |
text=file if is_txt else None,
|
@@ -2222,7 +2325,8 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2222 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
2223 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2224 |
|
2225 |
-
|
|
|
2226 |
if langchain_mode == 'MyData':
|
2227 |
if db1[0] is not None:
|
2228 |
# then add
|
@@ -2235,18 +2339,14 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2235 |
# for production hit, when user gets clicky:
|
2236 |
assert len(db1) == 2, "Bad MyData db: %s" % db1
|
2237 |
# then create
|
2238 |
-
# assign fresh hash for this user session, so not shared
|
2239 |
# if added has to original state and didn't change, then would be shared db for all users
|
2240 |
-
db1[1] = str(uuid.uuid4())
|
2241 |
persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
|
2242 |
db = get_db(sources, use_openai_embedding=use_openai_embedding,
|
2243 |
db_type=db_type,
|
2244 |
persist_directory=persist_directory,
|
2245 |
langchain_mode=langchain_mode,
|
2246 |
hf_embedding_model=hf_embedding_model)
|
2247 |
-
if db is None:
|
2248 |
-
db1[1] = None
|
2249 |
-
else:
|
2250 |
db1[0] = db
|
2251 |
source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
|
2252 |
return None, langchain_mode, db1, x, y, source_files_added
|
@@ -2274,7 +2374,9 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
|
|
2274 |
|
2275 |
|
2276 |
def get_db(db1, langchain_mode, dbs=None):
|
2277 |
-
|
|
|
|
|
2278 |
if langchain_mode in ['wiki_full']:
|
2279 |
# NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
|
2280 |
db = None
|
|
|
1 |
+
import ast
|
2 |
import copy
|
3 |
import functools
|
4 |
import inspect
|
|
|
50 |
|
51 |
fix_pydantic_duplicate_validators_error()
|
52 |
|
53 |
+
from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode
|
54 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
|
55 |
text_xsm
|
56 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
57 |
get_prompt
|
58 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
59 |
ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
|
60 |
+
from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
|
61 |
+
get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
|
62 |
+
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
|
63 |
|
64 |
from apscheduler.schedulers.background import BackgroundScheduler
|
65 |
|
|
|
100 |
dbs = kwargs['dbs']
|
101 |
db_type = kwargs['db_type']
|
102 |
visible_langchain_modes = kwargs['visible_langchain_modes']
|
103 |
+
visible_langchain_actions = kwargs['visible_langchain_actions']
|
104 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
105 |
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
|
106 |
enable_sources_list = kwargs['enable_sources_list']
|
|
|
215 |
'base_model') else no_model_msg
|
216 |
output_label0_model2 = no_model_msg
|
217 |
|
218 |
+
def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0):
|
219 |
+
if not prompt_type1 or which_model != 0:
|
220 |
+
# keep prompt_type and prompt_dict in sync if possible
|
221 |
+
prompt_type1 = kwargs.get('prompt_type', prompt_type1)
|
222 |
+
prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
|
223 |
+
# prefer model specific prompt type instead of global one
|
224 |
+
if not prompt_type1 or which_model != 0:
|
225 |
+
prompt_type1 = model_state1.get('prompt_type', prompt_type1)
|
226 |
+
prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
|
227 |
+
|
228 |
+
if not prompt_dict1 or which_model != 0:
|
229 |
+
# if still not defined, try to get
|
230 |
+
prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
|
231 |
+
if not prompt_dict1 or which_model != 0:
|
232 |
+
prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
|
233 |
+
return prompt_type1, prompt_dict1
|
234 |
+
|
235 |
default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
|
236 |
+
# ensure prompt_type consistent with prep_bot(), so nochat API works same way
|
237 |
+
default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \
|
238 |
+
update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'],
|
239 |
+
model_state1=model_state0, which_model=0)
|
240 |
for k in no_default_param_names:
|
241 |
default_kwargs[k] = ''
|
242 |
|
|
|
262 |
model_options_state = gr.State([model_options])
|
263 |
lora_options_state = gr.State([lora_options])
|
264 |
server_options_state = gr.State([server_options])
|
265 |
+
# uuid in db is used as user ID
|
266 |
+
my_db_state = gr.State([None, str(uuid.uuid4())])
|
267 |
chat_state = gr.State({})
|
268 |
# make user default first and default choice, dedup
|
269 |
docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
|
|
|
307 |
|
308 |
col_chat = gr.Column(visible=kwargs['chat'])
|
309 |
with col_chat:
|
310 |
+
instruction, submit, stop_btn = make_prompt_form(kwargs, LangChainMode)
|
311 |
text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
|
312 |
**kwargs)
|
313 |
|
|
|
356 |
value=kwargs['langchain_mode'],
|
357 |
label="Data Collection of Sources",
|
358 |
visible=kwargs['langchain_mode'] != 'Disabled')
|
359 |
+
allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
|
360 |
+
langchain_action = gr.Radio(
|
361 |
+
allowed_actions,
|
362 |
+
value=allowed_actions[0] if len(allowed_actions) > 0 else None,
|
363 |
+
label="Data Action",
|
364 |
+
visible=True)
|
365 |
data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
|
366 |
with data_row2:
|
367 |
with gr.Column(scale=50):
|
|
|
756 |
caption_loader=caption_loader,
|
757 |
verbose=kwargs['verbose'],
|
758 |
user_path=kwargs['user_path'],
|
759 |
+
n_jobs=kwargs['n_jobs'],
|
760 |
)
|
761 |
add_file_outputs = [fileup_output, langchain_mode, add_to_shared_db_btn, add_to_my_db_btn]
|
762 |
add_file_kwargs = dict(fn=update_user_db_func,
|
|
|
835 |
caption_loader=caption_loader,
|
836 |
verbose=kwargs['verbose'],
|
837 |
user_path=kwargs['user_path'],
|
838 |
+
n_jobs=kwargs['n_jobs'],
|
839 |
)
|
840 |
|
841 |
add_my_file_outputs = [fileup_output, langchain_mode, my_db_state, add_to_shared_db_btn, add_to_my_db_btn]
|
|
|
952 |
for k in inputs_kwargs_list:
|
953 |
assert k in kwargs_evaluate, "Missing %s" % k
|
954 |
|
955 |
+
def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
|
956 |
+
args_list = list(args1)
|
957 |
+
if str_api:
|
958 |
+
user_kwargs = args_list[2]
|
959 |
+
assert isinstance(user_kwargs, str)
|
960 |
+
user_kwargs = ast.literal_eval(user_kwargs)
|
961 |
+
else:
|
962 |
+
user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[2:])}
|
963 |
+
# only used for submit_nochat_api
|
964 |
+
user_kwargs['chat'] = False
|
965 |
+
if 'stream_output' not in user_kwargs:
|
966 |
+
user_kwargs['stream_output'] = False
|
967 |
+
if 'langchain_mode' not in user_kwargs:
|
968 |
+
# if user doesn't specify, then assume disabled, not use default
|
969 |
+
user_kwargs['langchain_mode'] = 'Disabled'
|
970 |
+
if 'langchain_action' not in user_kwargs:
|
971 |
+
user_kwargs['langchain_action'] = LangChainAction.QUERY.value
|
972 |
+
|
973 |
+
set1 = set(list(default_kwargs1.keys()))
|
974 |
+
set2 = set(eval_func_param_names)
|
975 |
+
assert set1 == set2, "Set diff: %s %s: %s" % (set1, set2, set1.symmetric_difference(set2))
|
976 |
+
# correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
|
977 |
+
model_state1 = args_list[0]
|
978 |
+
my_db_state1 = args_list[1]
|
979 |
+
args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
|
980 |
+
in eval_func_param_names]
|
981 |
+
assert len(args_list) == len(eval_func_param_names)
|
982 |
+
args_list = [model_state1, my_db_state1] + args_list
|
983 |
|
984 |
+
try:
|
985 |
+
for res_dict in evaluate(*tuple(args_list), **kwargs1):
|
986 |
+
if str_api:
|
987 |
+
# full return of dict
|
988 |
+
yield res_dict
|
989 |
+
elif kwargs['langchain_mode'] == 'Disabled':
|
990 |
+
yield fix_text_for_gradio(res_dict['response'])
|
991 |
+
else:
|
992 |
+
yield '<br>' + fix_text_for_gradio(res_dict['response'])
|
993 |
+
finally:
|
994 |
+
clear_torch_cache()
|
995 |
+
clear_embeddings(user_kwargs['langchain_mode'], my_db_state1)
|
996 |
+
|
997 |
+
fun = partial(evaluate_nochat,
|
998 |
+
default_kwargs1=default_kwargs,
|
999 |
+
str_api=False,
|
1000 |
**kwargs_evaluate)
|
1001 |
+
fun2 = partial(evaluate_nochat,
|
1002 |
+
default_kwargs1=default_kwargs,
|
1003 |
+
str_api=False,
|
1004 |
**kwargs_evaluate)
|
1005 |
+
fun_with_dict_str = partial(evaluate_nochat,
|
1006 |
+
default_kwargs1=default_kwargs,
|
1007 |
+
str_api=True,
|
1008 |
**kwargs_evaluate
|
1009 |
)
|
1010 |
|
|
|
1144 |
User that fills history for bot
|
1145 |
:param args:
|
1146 |
:param undo:
|
1147 |
+
:param retry:
|
1148 |
:param sanitize_user_prompt:
|
|
|
1149 |
:return:
|
1150 |
"""
|
1151 |
args_list = list(args)
|
1152 |
user_message = args_list[eval_func_param_names.index('instruction')] # chat only
|
1153 |
input1 = args_list[eval_func_param_names.index('iinput')] # chat only
|
1154 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1155 |
+
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1156 |
+
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1157 |
+
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1158 |
if not prompt_type1:
|
1159 |
# shouldn't have to specify if CLI launched model
|
1160 |
prompt_type1 = kwargs['prompt_type']
|
|
|
1185 |
history[-1][1] = None
|
1186 |
return history
|
1187 |
if user_message1 in ['', None, '\n']:
|
1188 |
+
if langchain_action1 in LangChainAction.QUERY.value and \
|
1189 |
+
DocumentChoices.Only_All_Sources.name not in document_choice1 \
|
1190 |
+
or \
|
1191 |
+
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1192 |
+
# reject non-retry submit/enter
|
1193 |
+
return history
|
1194 |
user_message1 = fix_text_for_gradio(user_message1)
|
1195 |
return history + [[user_message1, None]]
|
1196 |
|
|
|
1226 |
else:
|
1227 |
return 2000
|
1228 |
|
1229 |
+
def prep_bot(*args, retry=False, which_model=0):
|
1230 |
"""
|
1231 |
|
1232 |
:param args:
|
1233 |
:param retry:
|
1234 |
+
:param which_model: identifies which model if doing model_lock
|
1235 |
+
API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
|
1236 |
:return: last element is True if should run bot, False if should just yield history
|
1237 |
"""
|
1238 |
# don't deepcopy, can contain model itself
|
|
|
1240 |
model_state1 = args_list[-3]
|
1241 |
my_db_state1 = args_list[-2]
|
1242 |
history = args_list[-1]
|
1243 |
+
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1244 |
+
prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
|
1245 |
|
1246 |
if model_state1['model'] is None or model_state1['model'] == no_model_str:
|
1247 |
return history, None, None, None
|
1248 |
|
1249 |
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
1250 |
+
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1251 |
+
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1252 |
+
document_choice1 = args_list[eval_func_param_names.index('document_choice')]
|
1253 |
if not history:
|
1254 |
print("No history", flush=True)
|
1255 |
history = []
|
|
|
1260 |
instruction1 = history[-1][0]
|
1261 |
history[-1][1] = None
|
1262 |
elif not instruction1:
|
1263 |
+
if langchain_action1 in LangChainAction.QUERY.value and \
|
1264 |
+
DocumentChoices.Only_All_Sources.name not in document_choice1 \
|
1265 |
+
or \
|
1266 |
+
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1267 |
+
# if not retrying, then reject empty query
|
1268 |
+
return history, None, None, None
|
1269 |
elif len(history) > 0 and history[-1][1] not in [None, '']:
|
1270 |
# reject submit button if already filled and not retrying
|
1271 |
# None when not filling with '' to keep client happy
|
1272 |
return history, None, None, None
|
1273 |
|
1274 |
# shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
|
1275 |
+
prompt_type1, prompt_dict1 = update_prompt(prompt_type1, prompt_dict1, model_state1,
|
1276 |
+
which_model=which_model)
|
1277 |
+
# apply back to args_list for evaluate()
|
1278 |
+
args_list[eval_func_param_names.index('prompt_type')] = prompt_type1
|
1279 |
+
args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1
|
|
|
|
|
|
|
1280 |
|
1281 |
chat1 = args_list[eval_func_param_names.index('chat')]
|
1282 |
model_max_length1 = get_model_max_length(model_state1)
|
|
|
1350 |
for res in get_response(fun1, history):
|
1351 |
yield res
|
1352 |
finally:
|
1353 |
+
clear_torch_cache()
|
1354 |
clear_embeddings(langchain_mode1, my_db_state1)
|
1355 |
|
1356 |
def all_bot(*args, retry=False, model_states1=None):
|
|
|
1364 |
my_db_state1 = None # will be filled below by some bot
|
1365 |
try:
|
1366 |
gen_list = []
|
1367 |
+
for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
|
1368 |
args_list1 = args_list0.copy()
|
1369 |
args_list1.insert(-1, model_state1) # insert at -1 so is at -2
|
1370 |
# if at start, have None in response still, replace with '' so client etc. acts like normal
|
|
|
1376 |
# so consistent with prep_bot()
|
1377 |
# with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
|
1378 |
# langchain_mode1 and my_db_state1 should be same for every bot
|
1379 |
+
history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry,
|
1380 |
+
which_model=chatboti)
|
1381 |
gen1 = get_response(fun1, history)
|
1382 |
if stream_output1:
|
1383 |
gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
|
|
|
1389 |
tgen0 = time.time()
|
1390 |
for res1 in itertools.zip_longest(*gen_list):
|
1391 |
if time.time() - tgen0 > max_time1:
|
1392 |
+
print("Took too long: %s" % max_time1, flush=True)
|
1393 |
break
|
1394 |
|
1395 |
bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
|
|
|
1824 |
|
1825 |
def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
|
1826 |
infer_devices, gpu_id):
|
1827 |
+
# ensure no API calls reach here
|
1828 |
+
if is_public:
|
1829 |
+
raise RuntimeError("Illegal access for %s" % model_name)
|
1830 |
# ensure old model removed from GPU memory
|
1831 |
if kwargs['debug']:
|
1832 |
print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
|
|
|
2253 |
clear_torch_cache()
|
2254 |
|
2255 |
|
2256 |
+
def get_lock_file(db1, langchain_mode):
|
2257 |
+
assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str)
|
2258 |
+
user_id = db1[1]
|
2259 |
+
base_path = 'locks'
|
2260 |
+
makedirs(base_path)
|
2261 |
+
lock_file = "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)
|
2262 |
+
return lock_file
|
2263 |
+
|
2264 |
+
|
2265 |
def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
|
2266 |
user_path=None,
|
2267 |
use_openai_embedding=None,
|
|
|
2271 |
captions_model=None,
|
2272 |
enable_ocr=None,
|
2273 |
verbose=None,
|
2274 |
+
is_url=None, is_txt=None,
|
2275 |
+
n_jobs=-1):
|
2276 |
assert use_openai_embedding is not None
|
2277 |
assert hf_embedding_model is not None
|
2278 |
assert caption_loader is not None
|
|
|
2313 |
print("Adding %s" % file, flush=True)
|
2314 |
sources = path_to_docs(file if not is_url and not is_txt else None,
|
2315 |
verbose=verbose,
|
2316 |
+
n_jobs=n_jobs,
|
2317 |
chunk=chunk, chunk_size=chunk_size,
|
2318 |
url=file if is_url else None,
|
2319 |
text=file if is_txt else None,
|
|
|
2325 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
2326 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2327 |
|
2328 |
+
lock_file = get_lock_file(db1, langchain_mode)
|
2329 |
+
with filelock.FileLock(lock_file):
|
2330 |
if langchain_mode == 'MyData':
|
2331 |
if db1[0] is not None:
|
2332 |
# then add
|
|
|
2339 |
# for production hit, when user gets clicky:
|
2340 |
assert len(db1) == 2, "Bad MyData db: %s" % db1
|
2341 |
# then create
|
|
|
2342 |
# if added has to original state and didn't change, then would be shared db for all users
|
|
|
2343 |
persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
|
2344 |
db = get_db(sources, use_openai_embedding=use_openai_embedding,
|
2345 |
db_type=db_type,
|
2346 |
persist_directory=persist_directory,
|
2347 |
langchain_mode=langchain_mode,
|
2348 |
hf_embedding_model=hf_embedding_model)
|
2349 |
+
if db is not None:
|
|
|
|
|
2350 |
db1[0] = db
|
2351 |
source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
|
2352 |
return None, langchain_mode, db1, x, y, source_files_added
|
|
|
2374 |
|
2375 |
|
2376 |
def get_db(db1, langchain_mode, dbs=None):
|
2377 |
+
lock_file = get_lock_file(db1, langchain_mode)
|
2378 |
+
|
2379 |
+
with filelock.FileLock(lock_file):
|
2380 |
if langchain_mode in ['wiki_full']:
|
2381 |
# NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
|
2382 |
db = None
|
gradio_utils/__pycache__/grclient.cpython-310.pyc
CHANGED
Binary files a/gradio_utils/__pycache__/grclient.cpython-310.pyc and b/gradio_utils/__pycache__/grclient.cpython-310.pyc differ
|
|
gradio_utils/__pycache__/prompt_form.cpython-310.pyc
CHANGED
Binary files a/gradio_utils/__pycache__/prompt_form.cpython-310.pyc and b/gradio_utils/__pycache__/prompt_form.cpython-310.pyc differ
|
|
gradio_utils/prompt_form.py
CHANGED
@@ -95,11 +95,15 @@ def make_chatbots(output_label0, output_label0_model2, **kwargs):
|
|
95 |
return text_output, text_output2, text_outputs
|
96 |
|
97 |
|
98 |
-
def make_prompt_form(kwargs):
|
|
|
|
|
|
|
|
|
99 |
if kwargs['input_lines'] > 1:
|
100 |
-
instruction_label = "Shift-Enter to Submit, Enter for more lines"
|
101 |
else:
|
102 |
-
instruction_label = "Enter to Submit, Shift-Enter for more lines"
|
103 |
|
104 |
with gr.Row():#elem_id='prompt-form-area'):
|
105 |
with gr.Column(scale=50):
|
|
|
95 |
return text_output, text_output2, text_outputs
|
96 |
|
97 |
|
98 |
+
def make_prompt_form(kwargs, LangChainMode):
|
99 |
+
if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
|
100 |
+
extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
|
101 |
+
else:
|
102 |
+
extra_prompt_form = ""
|
103 |
if kwargs['input_lines'] > 1:
|
104 |
+
instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
|
105 |
else:
|
106 |
+
instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
|
107 |
|
108 |
with gr.Row():#elem_id='prompt-form-area'):
|
109 |
with gr.Column(scale=50):
|
h2oai_pipeline.py
CHANGED
@@ -136,6 +136,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
136 |
else:
|
137 |
outputs = rec['generated_text']
|
138 |
rec['generated_text'] = outputs
|
|
|
139 |
return records
|
140 |
|
141 |
def _forward(self, model_inputs, **generate_kwargs):
|
|
|
136 |
else:
|
137 |
outputs = rec['generated_text']
|
138 |
rec['generated_text'] = outputs
|
139 |
+
print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
|
140 |
return records
|
141 |
|
142 |
def _forward(self, model_inputs, **generate_kwargs):
|
prompter.py
CHANGED
@@ -120,7 +120,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context,
|
|
120 |
elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
|
121 |
PromptType.custom.name]:
|
122 |
promptA = prompt_dict.get('promptA', '')
|
123 |
-
promptB = prompt_dict('promptB', '')
|
124 |
PreInstruct = prompt_dict.get('PreInstruct', '')
|
125 |
PreInput = prompt_dict.get('PreInput', '')
|
126 |
PreResponse = prompt_dict.get('PreResponse', '')
|
@@ -693,7 +693,9 @@ class Prompter(object):
|
|
693 |
output = clean_response(output)
|
694 |
elif prompt is None:
|
695 |
# then use most basic parsing like pipeline
|
696 |
-
if self.botstr
|
|
|
|
|
697 |
if self.humanstr:
|
698 |
output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
|
699 |
else:
|
|
|
120 |
elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
|
121 |
PromptType.custom.name]:
|
122 |
promptA = prompt_dict.get('promptA', '')
|
123 |
+
promptB = prompt_dict.get('promptB', '')
|
124 |
PreInstruct = prompt_dict.get('PreInstruct', '')
|
125 |
PreInput = prompt_dict.get('PreInput', '')
|
126 |
PreResponse = prompt_dict.get('PreResponse', '')
|
|
|
693 |
output = clean_response(output)
|
694 |
elif prompt is None:
|
695 |
# then use most basic parsing like pipeline
|
696 |
+
if not self.botstr:
|
697 |
+
pass
|
698 |
+
elif self.botstr in output:
|
699 |
if self.humanstr:
|
700 |
output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
|
701 |
else:
|