pseudotensor commited on
Commit
9bcca78
1 Parent(s): 63ce3fa

Update with h2oGPT hash 1b295baace42908075b47f31a84b359d8c6b1e52

Browse files
Files changed (7) hide show
  1. client_test.py +5 -1
  2. finetune.py +1 -1
  3. generate.py +6 -3
  4. gpt_langchain.py +21 -4
  5. gradio_runner.py +144 -87
  6. prompter.py +2 -0
  7. utils.py +1 -1
client_test.py CHANGED
@@ -3,7 +3,7 @@ Client test.
3
 
4
  Run server:
5
 
6
- python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
7
 
8
  NOTE: For private models, add --use-auth_token=True
9
 
@@ -39,6 +39,7 @@ Loaded as API: https://gpt.h2o.ai ✔
39
  import time
40
  import os
41
  import markdown # pip install markdown
 
42
  from bs4 import BeautifulSoup # pip install beautifulsoup4
43
 
44
  debug = False
@@ -79,6 +80,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
79
  instruction_nochat=prompt if not chat else '',
80
  iinput_nochat='', # only for chat=False
81
  langchain_mode='Disabled',
 
82
  )
83
  if chat:
84
  # add chatbot output on end. Assumes serialize=False
@@ -87,6 +89,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
87
  return kwargs, list(kwargs.values())
88
 
89
 
 
90
  def test_client_basic():
91
  return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
92
 
@@ -106,6 +109,7 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
106
  return res_dict
107
 
108
 
 
109
  def test_client_chat():
110
  return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50)
111
 
 
3
 
4
  Run server:
5
 
6
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
7
 
8
  NOTE: For private models, add --use-auth_token=True
9
 
 
39
  import time
40
  import os
41
  import markdown # pip install markdown
42
+ import pytest
43
  from bs4 import BeautifulSoup # pip install beautifulsoup4
44
 
45
  debug = False
 
80
  instruction_nochat=prompt if not chat else '',
81
  iinput_nochat='', # only for chat=False
82
  langchain_mode='Disabled',
83
+ document_choice=['All'],
84
  )
85
  if chat:
86
  # add chatbot output on end. Assumes serialize=False
 
89
  return kwargs, list(kwargs.values())
90
 
91
 
92
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
93
  def test_client_basic():
94
  return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
95
 
 
109
  return res_dict
110
 
111
 
112
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
113
  def test_client_chat():
114
  return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50)
115
 
finetune.py CHANGED
@@ -26,7 +26,7 @@ def train(
26
  save_code: bool = False,
27
  run_id: int = None,
28
 
29
- base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
30
  # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
31
  # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
32
  # base_model: str = 'EleutherAI/gpt-neox-20b',
 
26
  save_code: bool = False,
27
  run_id: int = None,
28
 
29
+ base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
30
  # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
31
  # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
32
  # base_model: str = 'EleutherAI/gpt-neox-20b',
generate.py CHANGED
@@ -297,7 +297,7 @@ def main(
297
  if psutil.virtual_memory().available < 94 * 1024 ** 3:
298
  # 12B uses ~94GB
299
  # 6.9B uses ~47GB
300
- base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b' if not base_model else base_model
301
 
302
  # get defaults
303
  model_lower = base_model.lower()
@@ -864,6 +864,7 @@ eval_func_param_names = ['instruction',
864
  'instruction_nochat',
865
  'iinput_nochat',
866
  'langchain_mode',
 
867
  ]
868
 
869
 
@@ -891,6 +892,7 @@ def evaluate(
891
  instruction_nochat,
892
  iinput_nochat,
893
  langchain_mode,
 
894
  # END NOTE: Examples must have same order of parameters
895
  src_lang=None,
896
  tgt_lang=None,
@@ -1010,6 +1012,7 @@ def evaluate(
1010
  chunk=chunk,
1011
  chunk_size=chunk_size,
1012
  langchain_mode=langchain_mode,
 
1013
  db_type=db_type,
1014
  k=k,
1015
  temperature=temperature,
@@ -1446,7 +1449,7 @@ y = np.random.randint(0, 1, 100)
1446
 
1447
  # move to correct position
1448
  for example in examples:
1449
- example += [chat, '', '', 'Disabled']
1450
  # adjust examples if non-chat mode
1451
  if not chat:
1452
  example[eval_func_param_names.index('instruction_nochat')] = example[
@@ -1546,6 +1549,6 @@ if __name__ == "__main__":
1546
  can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
1547
  python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
1548
 
1549
- python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
1550
  """
1551
  fire.Fire(main)
 
297
  if psutil.virtual_memory().available < 94 * 1024 ** 3:
298
  # 12B uses ~94GB
299
  # 6.9B uses ~47GB
300
+ base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model
301
 
302
  # get defaults
303
  model_lower = base_model.lower()
 
864
  'instruction_nochat',
865
  'iinput_nochat',
866
  'langchain_mode',
867
+ 'document_choice',
868
  ]
869
 
870
 
 
892
  instruction_nochat,
893
  iinput_nochat,
894
  langchain_mode,
895
+ document_choice,
896
  # END NOTE: Examples must have same order of parameters
897
  src_lang=None,
898
  tgt_lang=None,
 
1012
  chunk=chunk,
1013
  chunk_size=chunk_size,
1014
  langchain_mode=langchain_mode,
1015
+ document_choice=document_choice,
1016
  db_type=db_type,
1017
  k=k,
1018
  temperature=temperature,
 
1449
 
1450
  # move to correct position
1451
  for example in examples:
1452
+ example += [chat, '', '', 'Disabled', ['All']]
1453
  # adjust examples if non-chat mode
1454
  if not chat:
1455
  example[eval_func_param_names.index('instruction_nochat')] = example[
 
1549
  can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
1550
  python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
1551
 
1552
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
1553
  """
1554
  fire.Fire(main)
gpt_langchain.py CHANGED
@@ -150,7 +150,7 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
150
  assert model_name is None
151
  assert tokenizer is None
152
  model_name = 'h2oai/h2ogpt-oasst1-512-12b'
153
- # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
154
  # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
155
  tokenizer = AutoTokenizer.from_pretrained(model_name)
156
  device, torch_dtype, context_class = get_device_dtype()
@@ -593,7 +593,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
593
  ):
594
  globs_image_types = []
595
  globs_non_image_types = []
596
- if path_or_paths is None:
597
  return []
598
  elif url:
599
  globs_non_image_types = [url]
@@ -846,6 +846,7 @@ def _run_qa_db(query=None,
846
  top_k=40,
847
  top_p=0.7,
848
  langchain_mode=None,
 
849
  n_jobs=-1):
850
  """
851
 
@@ -917,7 +918,23 @@ def _run_qa_db(query=None,
917
  k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
918
 
919
  if db and use_context:
920
- docs_with_score = db.similarity_search_with_score(query, k=k_db)[:k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
921
  # cut off so no high distance docs/sources considered
922
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
923
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
@@ -939,7 +956,7 @@ def _run_qa_db(query=None,
939
  reduced_query_words = reduced_query.split(' ')
940
  set_common = set(df['Lemma'].values.tolist())
941
  num_common = len([x.lower() in set_common for x in reduced_query_words])
942
- frac_common = num_common / len(reduced_query)
943
  # FIXME: report to user bad query that uses too many common words
944
  print("frac_common: %s" % frac_common, flush=True)
945
 
 
150
  assert model_name is None
151
  assert tokenizer is None
152
  model_name = 'h2oai/h2ogpt-oasst1-512-12b'
153
+ # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
154
  # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
155
  tokenizer = AutoTokenizer.from_pretrained(model_name)
156
  device, torch_dtype, context_class = get_device_dtype()
 
593
  ):
594
  globs_image_types = []
595
  globs_non_image_types = []
596
+ if not path_or_paths and not url and not text:
597
  return []
598
  elif url:
599
  globs_non_image_types = [url]
 
846
  top_k=40,
847
  top_p=0.7,
848
  langchain_mode=None,
849
+ document_choice=['All'],
850
  n_jobs=-1):
851
  """
852
 
 
918
  k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
919
 
920
  if db and use_context:
921
+ if isinstance(document_choice, str):
922
+ # support string as well
923
+ document_choice = [document_choice]
924
+ if not isinstance(db, Chroma) or len(document_choice) <= 1 and document_choice[0].lower() == 'all':
925
+ # treat empty list as All for now, not 'None'
926
+ filter_kwargs = {}
927
+ else:
928
+ if len(document_choice) >= 2:
929
+ or_filter = [{"source": {"$eq": x}} for x in document_choice]
930
+ filter_kwargs = dict(filter={"$or": or_filter})
931
+ else:
932
+ one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
933
+ filter_kwargs = dict(filter=one_filter)
934
+ if len(document_choice) == 1 and document_choice[0].lower() == 'none':
935
+ k_db = 1
936
+ k = 0
937
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
938
  # cut off so no high distance docs/sources considered
939
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
940
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
 
956
  reduced_query_words = reduced_query.split(' ')
957
  set_common = set(df['Lemma'].values.tolist())
958
  num_common = len([x.lower() in set_common for x in reduced_query_words])
959
+ frac_common = num_common / len(reduced_query) if reduced_query else 0
960
  # FIXME: report to user bad query that uses too many common words
961
  print("frac_common: %s" % frac_common, flush=True)
962
 
gradio_runner.py CHANGED
@@ -96,7 +96,13 @@ def go_gradio(**kwargs):
96
  css_code = """footer {visibility: hidden}"""
97
  css_code += """
98
  body.dark{#warning {background-color: #555555};}
99
- """
 
 
 
 
 
 
100
 
101
  if kwargs['gradio_avoid_processing_markdown']:
102
  from gradio_client import utils as client_utils
@@ -167,6 +173,7 @@ body.dark{#warning {background-color: #555555};}
167
  lora_options_state = gr.State([lora_options])
168
  my_db_state = gr.State([None, None])
169
  chat_state = gr.State({})
 
170
  gr.Markdown(f"""
171
  {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
172
 
@@ -175,7 +182,7 @@ body.dark{#warning {background-color: #555555};}
175
  """)
176
  if is_hf:
177
  gr.HTML(
178
- )
179
 
180
  # go button visible if
181
  base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
@@ -220,7 +227,7 @@ body.dark{#warning {background-color: #555555};}
220
  submit = gr.Button(value='Submit').style(full_width=False, size='sm')
221
  stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
222
  with gr.Row():
223
- clear = gr.Button("Save, New Conversation")
224
  flag_btn = gr.Button("Flag")
225
  if not kwargs['auto_score']: # FIXME: For checkbox model2
226
  with gr.Column(visible=kwargs['score_model']):
@@ -251,19 +258,16 @@ body.dark{#warning {background-color: #555555};}
251
  radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
252
  type='value')
253
  with gr.Row():
254
- remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
255
  clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
256
- chats_row = gr.Row(visible=True).style(equal_height=False)
257
- with chats_row:
258
- export_chats_btn = gr.Button(value="Export Chats")
259
- chats_file = gr.File(interactive=False, label="Download File")
260
- chats_row2 = gr.Row(visible=True).style(equal_height=False)
261
- with chats_row2:
262
  chatsup_output = gr.File(label="Upload Chat File(s)",
263
  file_types=['.json'],
264
  file_count='multiple',
265
  elem_id="warning", elem_classes="feedback")
266
- add_to_chats_btn = gr.Button("Add File(s) to Chats")
267
  with gr.TabItem("Data Source"):
268
  langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
269
  from_str=True)
@@ -275,8 +279,8 @@ body.dark{#warning {background-color: #555555};}
275
  <p>
276
  For more options see: {langchain_readme}""",
277
  visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
278
- data_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
279
- with data_row:
280
  if is_hf:
281
  # don't show 'wiki' since only usually useful for internal testing at moment
282
  no_show_modes = ['Disabled', 'wiki']
@@ -292,77 +296,92 @@ body.dark{#warning {background-color: #555555};}
292
  langchain_mode = gr.Radio(
293
  [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
294
  value=kwargs['langchain_mode'],
295
- label="Data Source",
296
  visible=kwargs['langchain_mode'] != 'Disabled')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
- def upload_file(files, x):
299
- file_paths = [file.name for file in files]
300
- return files, file_paths
301
-
302
- upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
303
- equal_height=False)
304
  # import control
305
  if kwargs['langchain_mode'] != 'Disabled':
306
  from gpt_langchain import file_types, have_arxiv
307
  else:
308
  have_arxiv = False
309
  file_types = []
310
- with upload_row:
311
- file_types_str = '[' + ' '.join(file_types) + ']'
312
- fileup_output = gr.File(label=f'Upload {file_types_str}',
313
- file_types=file_types,
314
- file_count="multiple",
315
- elem_id="warning", elem_classes="feedback")
316
- with gr.Row():
317
- upload_button = gr.UploadButton("Upload %s" % file_types_str,
318
- file_types=file_types,
319
- file_count="multiple",
320
- visible=False,
321
- )
322
- # add not visible until upload something
323
- with gr.Column():
324
- add_to_shared_db_btn = gr.Button("Add File(s) to Shared UserData DB",
325
- visible=allow_upload_to_user_data) # and False)
326
- add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData DB",
327
- visible=allow_upload_to_my_data) # and False)
328
- url_row = gr.Row(
329
- visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload).style(
330
  equal_height=False)
331
- with url_row:
332
- url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
333
- url_text = gr.Textbox(label=url_label, interactive=True)
334
  with gr.Column():
335
- url_user_btn = gr.Button(value='Add URL content to Shared UserData DB',
336
- visible=allow_upload_to_user_data)
337
- url_my_btn = gr.Button(value='Add URL content to Scratch MyData DB',
338
- visible=allow_upload_to_my_data)
339
- text_row = gr.Row(
340
- visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload).style(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  equal_height=False)
342
- with text_row:
343
- user_text_text = gr.Textbox(label='Paste Text', interactive=True)
344
- with gr.Column():
345
- user_text_user_btn = gr.Button(value='Add Text to Shared UserData DB',
346
- visible=allow_upload_to_user_data)
347
- user_text_my_btn = gr.Button(value='Add Text to Scratch MyData DB',
348
- visible=allow_upload_to_my_data)
349
- # WIP:
350
- with gr.Row(visible=False).style(equal_height=False):
351
- github_textbox = gr.Textbox(label="Github URL")
352
- with gr.Row(visible=True):
353
- github_shared_btn = gr.Button(value="Add Github to Shared UserData DB",
354
- visible=allow_upload_to_user_data)
355
- github_my_btn = gr.Button(value="Add Github to Scratch MyData DB",
356
- visible=allow_upload_to_my_data)
357
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
358
  equal_height=False)
359
  with sources_row:
360
  sources_text = gr.HTML(label='Sources Added', interactive=False)
361
- sources_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
362
- equal_height=False)
363
- with sources_row2:
364
- get_sources_btn = gr.Button(value="Get Sources List for Selected DB")
365
- file_source = gr.File(interactive=False, label="Download File with list of Sources")
366
 
367
  with gr.TabItem("Expert"):
368
  with gr.Row():
@@ -545,14 +564,6 @@ body.dark{#warning {background-color: #555555};}
545
  def make_visible():
546
  return gr.update(visible=True)
547
 
548
- # add itself to output to ensure shows working and can't click again
549
- upload_button.upload(upload_file, inputs=[upload_button, fileup_output],
550
- outputs=[upload_button, fileup_output], queue=queue,
551
- api_name='upload_file' if allow_api else None) \
552
- .then(make_add_visible, fileup_output, add_to_shared_db_btn, queue=queue) \
553
- .then(make_add_visible, fileup_output, add_to_my_db_btn, queue=queue) \
554
- .then(make_invisible, outputs=upload_button, queue=queue)
555
-
556
  # Add to UserData
557
  update_user_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='UserData',
558
  use_openai_embedding=use_openai_embedding,
@@ -623,8 +634,23 @@ body.dark{#warning {background-color: #555555};}
623
  .then(clear_textbox, outputs=user_text_text, queue=queue)
624
 
625
  get_sources1 = functools.partial(get_sources, dbs=dbs)
626
- get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=file_source, queue=queue,
627
- api_name='get_sources' if allow_api else None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628
 
629
  def check_admin_pass(x):
630
  return gr.update(visible=x == admin_pass)
@@ -818,6 +844,11 @@ body.dark{#warning {background-color: #555555};}
818
  my_db_state1 = args_list[-2]
819
  history = args_list[-1]
820
 
 
 
 
 
 
821
  args_list = args_list[:-3] # only keep rest needed for evaluate()
822
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
823
  if retry and history:
@@ -827,13 +858,19 @@ body.dark{#warning {background-color: #555555};}
827
  args_list[eval_func_param_names.index('do_sample')] = True
828
  if not history:
829
  print("No history", flush=True)
830
- history = [['', None]]
831
  yield history, ''
832
  return
833
  # ensure output will be unique to models
834
  _, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
835
  history = copy.deepcopy(history)
836
  instruction1 = history[-1][0]
 
 
 
 
 
 
837
  context1 = ''
838
  if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
839
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
@@ -867,10 +904,6 @@ body.dark{#warning {background-color: #555555};}
867
  context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
868
  args_list[0] = instruction1 # override original instruction with history from user
869
  args_list[2] = context1
870
- if model_state1[0] is None or model_state1[0] == no_model_str:
871
- history = [['', None]]
872
- yield history, ''
873
- return
874
  fun1 = partial(evaluate,
875
  model_state1,
876
  my_db_state1,
@@ -1086,10 +1119,14 @@ body.dark{#warning {background-color: #555555};}
1086
  api_name='export_chats' if allow_api else None)
1087
 
1088
  def add_chats_from_file(file, chat_state1, add_btn):
 
 
1089
  if isinstance(file, str):
1090
  files = [file]
1091
  else:
1092
  files = file
 
 
1093
  for file1 in files:
1094
  try:
1095
  if hasattr(file1, 'name'):
@@ -1350,22 +1387,28 @@ def get_inputs_list(inputs_dict, model_lower):
1350
  def get_sources(db1, langchain_mode, dbs=None):
1351
  if langchain_mode in ['ChatLLM', 'LLM']:
1352
  source_files_added = "NA"
 
1353
  elif langchain_mode in ['wiki_full']:
1354
  source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
1355
  " Ask [email protected] for file if required."
 
1356
  elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
1357
  db_get = db1[0].get()
1358
- source_files_added = '\n'.join(sorted(set([x['source'] for x in db_get['metadatas']])))
 
1359
  elif langchain_mode in dbs and dbs[langchain_mode] is not None:
1360
  db1 = dbs[langchain_mode]
1361
  db_get = db1.get()
1362
- source_files_added = '\n'.join(sorted(set([x['source'] for x in db_get['metadatas']])))
 
1363
  else:
 
1364
  source_files_added = "None"
1365
  sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
1366
  with open(sources_file, "wt") as f:
1367
  f.write(source_files_added)
1368
- return sources_file
 
1369
 
1370
 
1371
  def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
@@ -1465,6 +1508,20 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
1465
  return x, y, source_files_added
1466
 
1467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1468
  def get_source_files(db, exceptions=None):
1469
  if exceptions is None:
1470
  exceptions = []
 
96
  css_code = """footer {visibility: hidden}"""
97
  css_code += """
98
  body.dark{#warning {background-color: #555555};}
99
+ #small_btn {
100
+ margin: 0.6em 0em 0.55em 0;
101
+ max-width: 20em;
102
+ min-width: 5em !important;
103
+ height: 5em;
104
+ font-size: 14px !important
105
+ }"""
106
 
107
  if kwargs['gradio_avoid_processing_markdown']:
108
  from gradio_client import utils as client_utils
 
173
  lora_options_state = gr.State([lora_options])
174
  my_db_state = gr.State([None, None])
175
  chat_state = gr.State({})
176
+ docs_state = gr.State(['All'])
177
  gr.Markdown(f"""
178
  {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
179
 
 
182
  """)
183
  if is_hf:
184
  gr.HTML(
185
+ )
186
 
187
  # go button visible if
188
  base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
 
227
  submit = gr.Button(value='Submit').style(full_width=False, size='sm')
228
  stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
229
  with gr.Row():
230
+ clear = gr.Button("Save Chat / New Chat")
231
  flag_btn = gr.Button("Flag")
232
  if not kwargs['auto_score']: # FIXME: For checkbox model2
233
  with gr.Column(visible=kwargs['score_model']):
 
258
  radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
259
  type='value')
260
  with gr.Row():
 
261
  clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
262
+ export_chats_btn = gr.Button(value="Export Chats to Download")
263
+ remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
264
+ add_to_chats_btn = gr.Button("Import Chats from Upload")
265
+ with gr.Row():
266
+ chats_file = gr.File(interactive=False, label="Download Exported Chats")
 
267
  chatsup_output = gr.File(label="Upload Chat File(s)",
268
  file_types=['.json'],
269
  file_count='multiple',
270
  elem_id="warning", elem_classes="feedback")
 
271
  with gr.TabItem("Data Source"):
272
  langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
273
  from_str=True)
 
279
  <p>
280
  For more options see: {langchain_readme}""",
281
  visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
282
+ data_row1 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
283
+ with data_row1:
284
  if is_hf:
285
  # don't show 'wiki' since only usually useful for internal testing at moment
286
  no_show_modes = ['Disabled', 'wiki']
 
296
  langchain_mode = gr.Radio(
297
  [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
298
  value=kwargs['langchain_mode'],
299
+ label="Data Collection of Sources",
300
  visible=kwargs['langchain_mode'] != 'Disabled')
301
+ data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
302
+ with data_row2:
303
+ with gr.Column(scale=50):
304
+ document_choice = gr.Dropdown(docs_state.value,
305
+ label="Choose Subset of Doc(s) in Collection [click get to update]",
306
+ value=docs_state.value[0],
307
+ interactive=True,
308
+ multiselect=True,
309
+ )
310
+ with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list):
311
+ get_sources_btn = gr.Button(value="Get Sources",
312
+ ).style(full_width=False, size='sm')
313
+ show_sources_btn = gr.Button(value="Show Sources",
314
+ ).style(full_width=False, size='sm')
315
 
 
 
 
 
 
 
316
  # import control
317
  if kwargs['langchain_mode'] != 'Disabled':
318
  from gpt_langchain import file_types, have_arxiv
319
  else:
320
  have_arxiv = False
321
  file_types = []
322
+
323
+ upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  equal_height=False)
325
+ with upload_row:
 
 
326
  with gr.Column():
327
+ file_types_str = '[' + ' '.join(file_types) + ']'
328
+ fileup_output = gr.File(label=f'Upload {file_types_str}',
329
+ file_types=file_types,
330
+ file_count="multiple",
331
+ elem_id="warning", elem_classes="feedback")
332
+ with gr.Row():
333
+ add_to_shared_db_btn = gr.Button("Add File(s) to UserData",
334
+ visible=allow_upload_to_user_data, elem_id='small_btn')
335
+ add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData",
336
+ visible=allow_upload_to_my_data,
337
+ elem_id='small_btn' if allow_upload_to_user_data else None,
338
+ ).style(
339
+ size='sm' if not allow_upload_to_user_data else None)
340
+ with gr.Column(
341
+ visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload):
342
+ url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
343
+ url_text = gr.Textbox(label=url_label, interactive=True)
344
+ with gr.Row():
345
+ url_user_btn = gr.Button(value='Add URL content to Shared UserData',
346
+ visible=allow_upload_to_user_data, elem_id='small_btn')
347
+ url_my_btn = gr.Button(value='Add URL content to Scratch MyData',
348
+ visible=allow_upload_to_my_data,
349
+ elem_id='small_btn' if allow_upload_to_user_data else None,
350
+ ).style(size='sm' if not allow_upload_to_user_data else None)
351
+ with gr.Column(
352
+ visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload):
353
+ user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]', interactive=True)
354
+ with gr.Row():
355
+ user_text_user_btn = gr.Button(value='Add Text to Shared UserData',
356
+ visible=allow_upload_to_user_data,
357
+ elem_id='small_btn')
358
+ user_text_my_btn = gr.Button(value='Add Text to Scratch MyData',
359
+ visible=allow_upload_to_my_data,
360
+ elem_id='small_btn' if allow_upload_to_user_data else None,
361
+ ).style(
362
+ size='sm' if not allow_upload_to_user_data else None)
363
+ with gr.Column(visible=False):
364
+ # WIP:
365
+ with gr.Row(visible=False).style(equal_height=False):
366
+ github_textbox = gr.Textbox(label="Github URL")
367
+ with gr.Row(visible=True):
368
+ github_shared_btn = gr.Button(value="Add Github to Shared UserData",
369
+ visible=allow_upload_to_user_data,
370
+ elem_id='small_btn')
371
+ github_my_btn = gr.Button(value="Add Github to Scratch MyData",
372
+ visible=allow_upload_to_my_data, elem_id='small_btn')
373
+ sources_row3 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
374
  equal_height=False)
375
+ with sources_row3:
376
+ with gr.Column(scale=1):
377
+ file_source = gr.File(interactive=False,
378
+ label="Download File with Sources [click get to make file]")
379
+ with gr.Column(scale=2):
380
+ pass
 
 
 
 
 
 
 
 
 
381
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
382
  equal_height=False)
383
  with sources_row:
384
  sources_text = gr.HTML(label='Sources Added', interactive=False)
 
 
 
 
 
385
 
386
  with gr.TabItem("Expert"):
387
  with gr.Row():
 
564
  def make_visible():
565
  return gr.update(visible=True)
566
 
 
 
 
 
 
 
 
 
567
  # Add to UserData
568
  update_user_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='UserData',
569
  use_openai_embedding=use_openai_embedding,
 
634
  .then(clear_textbox, outputs=user_text_text, queue=queue)
635
 
636
  get_sources1 = functools.partial(get_sources, dbs=dbs)
637
+
638
+ # if change collection source, must clear doc selections from it to avoid inconsistency
639
+ def clear_doc_choice():
640
+ return gr.Dropdown.update(choices=['All'], value=['All'])
641
+
642
+ langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice)
643
+
644
+ def update_dropdown(x):
645
+ return gr.Dropdown.update(choices=x, value='All')
646
+
647
+ show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
648
+ get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=[file_source, docs_state],
649
+ queue=queue,
650
+ api_name='get_sources' if allow_api else None) \
651
+ .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
652
+ # show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
653
+ show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text)
654
 
655
  def check_admin_pass(x):
656
  return gr.update(visible=x == admin_pass)
 
844
  my_db_state1 = args_list[-2]
845
  history = args_list[-1]
846
 
847
+ if model_state1[0] is None or model_state1[0] == no_model_str:
848
+ history = []
849
+ yield history, ''
850
+ return
851
+
852
  args_list = args_list[:-3] # only keep rest needed for evaluate()
853
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
854
  if retry and history:
 
858
  args_list[eval_func_param_names.index('do_sample')] = True
859
  if not history:
860
  print("No history", flush=True)
861
+ history = []
862
  yield history, ''
863
  return
864
  # ensure output will be unique to models
865
  _, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
866
  history = copy.deepcopy(history)
867
  instruction1 = history[-1][0]
868
+ if not instruction1:
869
+ # reject empty query, can sometimes go nuts
870
+ history = []
871
+ yield history, ''
872
+ return
873
+
874
  context1 = ''
875
  if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
876
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
 
904
  context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
905
  args_list[0] = instruction1 # override original instruction with history from user
906
  args_list[2] = context1
 
 
 
 
907
  fun1 = partial(evaluate,
908
  model_state1,
909
  my_db_state1,
 
1119
  api_name='export_chats' if allow_api else None)
1120
 
1121
  def add_chats_from_file(file, chat_state1, add_btn):
1122
+ if not file:
1123
+ return chat_state1, add_btn
1124
  if isinstance(file, str):
1125
  files = [file]
1126
  else:
1127
  files = file
1128
+ if not files:
1129
+ return chat_state1, add_btn
1130
  for file1 in files:
1131
  try:
1132
  if hasattr(file1, 'name'):
 
1387
  def get_sources(db1, langchain_mode, dbs=None):
1388
  if langchain_mode in ['ChatLLM', 'LLM']:
1389
  source_files_added = "NA"
1390
+ source_list = []
1391
  elif langchain_mode in ['wiki_full']:
1392
  source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
1393
  " Ask [email protected] for file if required."
1394
+ source_list = []
1395
  elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
1396
  db_get = db1[0].get()
1397
+ source_list = sorted(set([x['source'] for x in db_get['metadatas']]))
1398
+ source_files_added = '\n'.join(source_list)
1399
  elif langchain_mode in dbs and dbs[langchain_mode] is not None:
1400
  db1 = dbs[langchain_mode]
1401
  db_get = db1.get()
1402
+ source_list = sorted(set([x['source'] for x in db_get['metadatas']]))
1403
+ source_files_added = '\n'.join(source_list)
1404
  else:
1405
+ source_list = []
1406
  source_files_added = "None"
1407
  sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
1408
  with open(sources_file, "wt") as f:
1409
  f.write(source_files_added)
1410
+ source_list = ['All'] + source_list
1411
+ return sources_file, source_list
1412
 
1413
 
1414
  def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
 
1508
  return x, y, source_files_added
1509
 
1510
 
1511
+ def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
1512
+ with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
1513
+ if langchain_mode in ['wiki_full']:
1514
+ # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
1515
+ db = None
1516
+ elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
1517
+ db = db1[0]
1518
+ elif langchain_mode in dbs and dbs[langchain_mode] is not None:
1519
+ db = dbs[langchain_mode]
1520
+ else:
1521
+ db = None
1522
+ return get_source_files(db, exceptions=None)
1523
+
1524
+
1525
  def get_source_files(db, exceptions=None):
1526
  if exceptions is None:
1527
  exceptions = []
prompter.py CHANGED
@@ -56,6 +56,8 @@ prompt_type_to_model_name = {
56
  'h2oai/h2ogpt-oasst1-512-20b',
57
  'h2oai/h2ogpt-oig-oasst1-256-6_9b',
58
  'h2oai/h2ogpt-oig-oasst1-512-6_9b',
 
 
59
  'h2oai/h2ogpt-research-oasst1-512-30b', # private
60
  ],
61
  'dai_faq': [],
 
56
  'h2oai/h2ogpt-oasst1-512-20b',
57
  'h2oai/h2ogpt-oig-oasst1-256-6_9b',
58
  'h2oai/h2ogpt-oig-oasst1-512-6_9b',
59
+ 'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
60
+ 'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
61
  'h2oai/h2ogpt-research-oasst1-512-30b', # private
62
  ],
63
  'dai_faq': [],
utils.py CHANGED
@@ -148,7 +148,7 @@ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
148
  host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
149
  zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
150
  assert root_dirs is not None
151
- if not os.path.isdir(os.path.dirname(zip_file)):
152
  os.makedirs(os.path.dirname(zip_file), exist_ok=True)
153
  with zipfile.ZipFile(zip_file, "w") as expt_zip:
154
  for root_dir in root_dirs:
 
148
  host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
149
  zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
150
  assert root_dirs is not None
151
+ if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname(zip_file):
152
  os.makedirs(os.path.dirname(zip_file), exist_ok=True)
153
  with zipfile.ZipFile(zip_file, "w") as expt_zip:
154
  for root_dir in root_dirs: