shaocongma commited on
Commit
94dc00e
·
1 Parent(s): 328e8d0

Add a generator wrapper using configuration file. Edit the logic of searching references. Add Gradio UI for testing Knowledge database.

Browse files
api_wrapper.py DELETED
@@ -1,42 +0,0 @@
1
- '''
2
- This script is used to wrap all generation methods together.
3
-
4
- todo:
5
- A worker keeps running on the server. Monitor the Amazon SQS. Once receive a new message, do the following:
6
- Download the corresponding configuration files on S3.
7
- Change Task status from Pending to Running.
8
- Call `generator_wrapper` and wait for the outputs.
9
- If `generator_wrapper` returns results:
10
- evaluate the results; compile it; upload results to S3 ... Change Task status from Running to Completed.
11
- If anything goes wrong, raise Error.
12
- If `generator_wrapper` returns nothing or Timeout, or raise any error:
13
- Change Task status from Running to Failed.
14
- '''
15
- import os.path
16
-
17
- from auto_backgrounds import generate_draft
18
- import json, time
19
- from utils.file_operations import make_archive
20
-
21
-
22
- GENERATOR_MAPPING = {"fake": None, # a fake generator
23
- "draft": generate_draft # generate academic paper
24
- }
25
-
26
- def generator_wrapper(config):
27
- generator = GENERATOR_MAPPING[config["generator"]]
28
-
29
-
30
- def generator_wrapper_from_json(path_to_config_json):
31
- # Read configuration file and call corresponding function
32
- with open(path_to_config_json, "r", encoding='utf-8') as f:
33
- config = json.load(f)
34
- print("Configuration:", config)
35
- # generator = GENERATOR_MAPPING.get(config["generator"])
36
- generator = None
37
- if generator is None:
38
- # generate a fake ZIP file and upload
39
- time.sleep(150)
40
- zip_path = os.path.splitext(path_to_config_json)[0]+".zip"
41
- return make_archive(path_to_config_json, zip_path)
42
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -2,9 +2,10 @@ import uuid
2
  import gradio as gr
3
  import os
4
  import openai
5
- from auto_backgrounds import generate_backgrounds, generate_draft
6
  from utils.file_operations import list_folders, urlify
7
  from huggingface_hub import snapshot_download
 
8
 
9
  # todo:
10
  # 6. get logs when the procedure is not completed. *
@@ -22,8 +23,10 @@ from huggingface_hub import snapshot_download
22
  # OPENAI_API_BASE: (Optional) Support alternative OpenAI minors
23
  # GPT4_ENABLE: (Optional) Set it to 1 to enable GPT-4 model.
24
 
25
- # AWS_ACCESS_KEY_ID: (Optional) Access AWS cloud storage (you need to edit `BUCKET_NAME` in `utils/storage.py` if you need to use this function)
26
- # AWS_SECRET_ACCESS_KEY: (Optional) Access AWS cloud storage (you need to edit `BUCKET_NAME` in `utils/storage.py` if you need to use this function)
 
 
27
  # KDB_REPO: (Optional) A Huggingface dataset hosting Knowledge Databases
28
  # HF_TOKEN: (Optional) Access to KDB_REPO
29
 
@@ -34,7 +37,7 @@ openai_key = os.getenv("OPENAI_API_KEY")
34
  openai_api_base = os.getenv("OPENAI_API_BASE")
35
  if openai_api_base is not None:
36
  openai.api_base = openai_api_base
37
- GPT4_ENABLE = os.getenv("GPT4_ENABLE") # disable GPT-4 for public repo
38
 
39
  access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
40
  secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
@@ -124,7 +127,7 @@ REFERENCES = """## 一键搜索相关论文
124
  REFERENCES_INSTRUCTION = """### References
125
  这一栏用于定义AI如何选取参考文献. 目前是两种方式混合:
126
  1. GPT自动根据标题生成关键字,使用Semantic Scholar搜索引擎搜索文献,利用Specter获取Paper Embedding来自动选取最相关的文献作为GPT的参考资料.
127
- 2. 用户上传bibtex文件,使用Google Scholar搜索摘要作为GPT的参考资料.
128
  关于有希望利用本地文件来供GPT参考的功能将在未来实装.
129
  """
130
 
@@ -140,7 +143,7 @@ OUTPUTS_INSTRUCTION = """### Outputs
140
  这一栏用于定义输出的内容:
141
  * Template: 用于填装内容的LaTeX模板.
142
  * Models: 使用GPT-4或者GPT-3.5-Turbo生成内容.
143
- * Prompts模式: 不生成内容, 而是生成用于生成内容的Prompts. 可以手动复制到网页版或者其他语言模型中进行使用.
144
  """
145
 
146
  OTHERS_INSTRUCTION = """### Others
@@ -164,18 +167,34 @@ def clear_inputs(*args):
164
  def clear_inputs_refs(*args):
165
  return "", 5
166
 
 
167
  def wrapped_generator(
168
  paper_title, paper_description, # main input
169
- openai_api_key=None, openai_url=None, # key
170
- tldr=True, max_kw_refs=10, bib_refs=None, max_tokens_ref=2048, # references
171
  knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
172
  paper_template="ICLR2022", selected_sections=None, model="gpt-4", prompts_mode=False, # outputs parameters
173
  cache_mode=IS_CACHE_AVAILABLE # handle cache mode
174
  ):
175
- # if `cache_mode` is True, then always upload the generated content to my S3.
176
  file_name_upload = urlify(paper_title) + "_" + uuid.uuid1().hex + ".zip"
177
- if bib_refs is not None:
178
- bib_refs = bib_refs.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  if openai_api_key is not None:
180
  openai.api_key = openai_api_key
181
  try:
@@ -183,12 +202,7 @@ def wrapped_generator(
183
  except Exception as e:
184
  raise gr.Error(f"Key错误. Error: {e}")
185
  try:
186
- output = generate_draft(
187
- paper_title, description=paper_description, # main input
188
- tldr=tldr, max_kw_refs=max_kw_refs, bib_refs=bib_refs, max_tokens_ref=max_tokens_ref, # references
189
- knowledge_database=knowledge_database, max_tokens_kd=max_tokens_kd, query_counts=query_counts, # domain knowledge
190
- sections=selected_sections, model=model, template=paper_template, prompts_mode=prompts_mode, # outputs parameters
191
- )
192
  if cache_mode:
193
  from utils.storage import upload_file
194
  upload_file(output, target_name=file_name_upload)
@@ -204,8 +218,6 @@ with gr.Blocks(theme=theme) as demo:
204
  with gr.Column(scale=2):
205
  key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key",
206
  visible=not IS_OPENAI_API_KEY_AVAILABLE)
207
- url = gr.Textbox(value=None, lines=1, max_lines=1, label="URL",
208
- visible=False)
209
  # 每个功能做一个tab
210
  with gr.Tab("学术论文"):
211
  gr.Markdown(ACADEMIC_PAPER)
@@ -230,8 +242,8 @@ with gr.Blocks(theme=theme) as demo:
230
  interactive=GPT4_INTERACTIVE,
231
  info="生成论文用到的语言模型.")
232
  prompts_mode = gr.Checkbox(value=False, visible=True, interactive=True,
233
- label="Prompts模式",
234
- info="只输出用于生成论文的Prompts, 可以复制到别的地方生成论文.")
235
 
236
  sections = gr.CheckboxGroup(
237
  choices=["introduction", "related works", "backgrounds", "methodology", "experiments",
@@ -245,21 +257,27 @@ with gr.Blocks(theme=theme) as demo:
245
 
246
  with gr.Column(scale=2):
247
  max_kw_ref_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
248
- interactive=True, label="MAX_KW_REFS",
249
- info="每个Keyword搜索几篇参考文献", visible=False)
250
 
251
  max_tokens_ref_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
252
- interactive=True, label="MAX_TOKENS",
253
- info="参考文献内容占用Prompts中的Token数")
254
 
255
  tldr_checkbox = gr.Checkbox(value=True, label="TLDR;",
256
  info="选择此筐表示将使用Semantic Scholar的TLDR作为文献的总结.",
257
  interactive=True)
258
- gr.Markdown('''
259
- 上传.bib文件提供AI需要参考的文献.
260
- ''')
261
- bibtex_file = gr.File(label="Upload .bib file", file_types=["text"],
262
- interactive=True)
 
 
 
 
 
 
263
 
264
  with gr.Row():
265
  with gr.Column(scale=1):
@@ -267,11 +285,11 @@ with gr.Blocks(theme=theme) as demo:
267
 
268
  with gr.Column(scale=2):
269
  query_counts_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
270
- interactive=True, label="QUERY_COUNTS",
271
- info="从知识库内检索多少条内���", visible=False)
272
  max_tokens_kd_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
273
- interactive=True, label="MAX_TOKENS",
274
- info="知识库内容占用Prompts中的Token数")
275
  domain_knowledge = gr.Dropdown(label="预载知识库",
276
  choices=ALL_DATABASES,
277
  value="(None)",
@@ -296,8 +314,8 @@ with gr.Blocks(theme=theme) as demo:
296
  json_output = gr.JSON(label="References")
297
  clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
298
  submit_button_pp.click(fn=wrapped_generator,
299
- inputs=[title, description_pp, key, url,
300
- tldr_checkbox, max_kw_ref_slider, bibtex_file, max_tokens_ref_slider,
301
  domain_knowledge, max_tokens_kd_slider, query_counts_slider,
302
  template, sections, model_selection, prompts_mode], outputs=file_output)
303
 
 
2
  import gradio as gr
3
  import os
4
  import openai
5
+ import yaml
6
  from utils.file_operations import list_folders, urlify
7
  from huggingface_hub import snapshot_download
8
+ from wrapper import generator_wrapper
9
 
10
  # todo:
11
  # 6. get logs when the procedure is not completed. *
 
23
  # OPENAI_API_BASE: (Optional) Support alternative OpenAI minors
24
  # GPT4_ENABLE: (Optional) Set it to 1 to enable GPT-4 model.
25
 
26
+ # AWS_ACCESS_KEY_ID: (Optional)
27
+ # Access AWS cloud storage (you need to edit `BUCKET_NAME` in `utils/storage.py` if you need to use this function)
28
+ # AWS_SECRET_ACCESS_KEY: (Optional)
29
+ # Access AWS cloud storage (you need to edit `BUCKET_NAME` in `utils/storage.py` if you need to use this function)
30
  # KDB_REPO: (Optional) A Huggingface dataset hosting Knowledge Databases
31
  # HF_TOKEN: (Optional) Access to KDB_REPO
32
 
 
37
  openai_api_base = os.getenv("OPENAI_API_BASE")
38
  if openai_api_base is not None:
39
  openai.api_base = openai_api_base
40
+ GPT4_ENABLE = os.getenv("GPT4_ENABLE") # disable GPT-4 for public repo
41
 
42
  access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
43
  secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
 
127
  REFERENCES_INSTRUCTION = """### References
128
  这一栏用于定义AI如何选取参考文献. 目前是两种方式混合:
129
  1. GPT自动根据标题生成关键字,使用Semantic Scholar搜索引擎搜索文献,利用Specter获取Paper Embedding来自动选取最相关的文献作为GPT的参考资料.
130
+ 2. 用户通过输入文章标题(用英文逗号隔开), AI会自动搜索文献作为参考资料.
131
  关于有希望利用本地文件来供GPT参考的功能将在未来实装.
132
  """
133
 
 
143
  这一栏用于定义输出的内容:
144
  * Template: 用于填装内容的LaTeX模板.
145
  * Models: 使用GPT-4或者GPT-3.5-Turbo生成内容.
146
+ * Prompts模式: 不生成内容, 而是生成用于生成内容的Prompts. 可以手动复制到网页版或者其他语言模型中进行使用. (放在输出的ZIP文件的prompts.json文件中)
147
  """
148
 
149
  OTHERS_INSTRUCTION = """### Others
 
167
  def clear_inputs_refs(*args):
168
  return "", 5
169
 
170
+
171
  def wrapped_generator(
172
  paper_title, paper_description, # main input
173
+ openai_api_key=None, # key
174
+ tldr=True, max_kw_refs=10, refs=None, max_tokens_ref=2048, # references
175
  knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
176
  paper_template="ICLR2022", selected_sections=None, model="gpt-4", prompts_mode=False, # outputs parameters
177
  cache_mode=IS_CACHE_AVAILABLE # handle cache mode
178
  ):
 
179
  file_name_upload = urlify(paper_title) + "_" + uuid.uuid1().hex + ".zip"
180
+
181
+ # load the default configuration file
182
+ with open("configurations/default.yaml", 'r') as file:
183
+ config = yaml.safe_load(file)
184
+ config["paper"]["title"] = paper_title
185
+ config["paper"]["description"] = paper_description
186
+ config["references"]["tldr"] = tldr
187
+ config["references"]["max_kw_refs"] = max_kw_refs
188
+ config["references"]["refs"] = refs
189
+ config["references"]["max_tokens_ref"] = max_tokens_ref
190
+ config["domain_knowledge"]["knowledge_database"] = knowledge_database
191
+ config["domain_knowledge"]["max_tokens_kd"] = max_tokens_kd
192
+ config["domain_knowledge"]["query_counts"] = query_counts
193
+ config["output"]["selected_sections"] = selected_sections
194
+ config["output"]["model"] = model
195
+ config["output"]["template"] = paper_template
196
+ config["output"]["prompts_mode"] = prompts_mode
197
+
198
  if openai_api_key is not None:
199
  openai.api_key = openai_api_key
200
  try:
 
202
  except Exception as e:
203
  raise gr.Error(f"Key错误. Error: {e}")
204
  try:
205
+ output = generator_wrapper(config)
 
 
 
 
 
206
  if cache_mode:
207
  from utils.storage import upload_file
208
  upload_file(output, target_name=file_name_upload)
 
218
  with gr.Column(scale=2):
219
  key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key",
220
  visible=not IS_OPENAI_API_KEY_AVAILABLE)
 
 
221
  # 每个功能做一个tab
222
  with gr.Tab("学术论文"):
223
  gr.Markdown(ACADEMIC_PAPER)
 
242
  interactive=GPT4_INTERACTIVE,
243
  info="生成论文用到的语言模型.")
244
  prompts_mode = gr.Checkbox(value=False, visible=True, interactive=True,
245
+ label="Prompts模式",
246
+ info="只输出用于生成论文的Prompts, 可以复制到别的地方生成论文.")
247
 
248
  sections = gr.CheckboxGroup(
249
  choices=["introduction", "related works", "backgrounds", "methodology", "experiments",
 
257
 
258
  with gr.Column(scale=2):
259
  max_kw_ref_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
260
+ interactive=True, label="MAX_KW_REFS",
261
+ info="每个Keyword搜索几篇参考文献", visible=False)
262
 
263
  max_tokens_ref_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
264
+ interactive=True, label="MAX_TOKENS",
265
+ info="参考文献内容占用Prompts中的Token数")
266
 
267
  tldr_checkbox = gr.Checkbox(value=True, label="TLDR;",
268
  info="选择此筐表示将使用Semantic Scholar的TLDR作为文献的总结.",
269
  interactive=True)
270
+
271
+ text_ref = gr.Textbox(lines=5, label="References (Optional)", visible=True,
272
+ info="交给AI参考的文献的标题, 用英文逗号`,`隔开.")
273
+
274
+ gr.Examples(
275
+ examples = ["Understanding the Impact of Model Incoherence on Convergence of Incremental SGD with Random Reshuffle,"
276
+ "Variance-Reduced Off-Policy TDC Learning: Non-Asymptotic Convergence Analysis,"
277
+ "Greedy-GQ with Variance Reduction: Finite-time Analysis and Improved Complexity"],
278
+ inputs=text_ref,
279
+ cache_examples=False
280
+ )
281
 
282
  with gr.Row():
283
  with gr.Column(scale=1):
 
285
 
286
  with gr.Column(scale=2):
287
  query_counts_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
288
+ interactive=True, label="QUERY_COUNTS",
289
+ info="从知识库内检索多少条内容", visible=False)
290
  max_tokens_kd_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
291
+ interactive=True, label="MAX_TOKENS",
292
+ info="知识库内容占用Prompts中的Token数")
293
  domain_knowledge = gr.Dropdown(label="预载知识库",
294
  choices=ALL_DATABASES,
295
  value="(None)",
 
314
  json_output = gr.JSON(label="References")
315
  clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
316
  submit_button_pp.click(fn=wrapped_generator,
317
+ inputs=[title, description_pp, key,
318
+ tldr_checkbox, max_kw_ref_slider, text_ref, max_tokens_ref_slider,
319
  domain_knowledge, max_tokens_kd_slider, query_counts_slider,
320
  template, sections, model_selection, prompts_mode], outputs=file_output)
321
 
assets/idealab.png DELETED
Binary file (52.1 kB)
 
auto_backgrounds.py → auto_generators.py RENAMED
@@ -40,7 +40,7 @@ def log_usage(usage, generating_target, print_out=True):
40
 
41
 
42
  def _generation_setup(title, description="", template="ICLR2022",
43
- tldr=False, max_kw_refs=10, bib_refs=None, max_tokens_ref=2048, # generating references
44
  knowledge_database=None, max_tokens_kd=2048, query_counts=10, # querying from knowledge database
45
  debug=True):
46
  """
@@ -115,7 +115,7 @@ def _generation_setup(title, description="", template="ICLR2022",
115
 
116
  print("Keywords: \n", keywords)
117
  # todo: in some rare situations, collected papers will be an empty list. handle this issue
118
- ref = References(title, bib_refs)
119
  ref.collect_papers(keywords, tldr=tldr)
120
  references = ref.to_prompts(max_tokens=max_tokens_ref)
121
  all_paper_ids = ref.to_bibtex(bibtex_path)
@@ -200,7 +200,7 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
200
 
201
 
202
  def generate_draft(title, description="", # main input
203
- tldr=True, max_kw_refs=10, bib_refs=None, max_tokens_ref=2048, # references
204
  knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
205
  sections=None, model="gpt-4", template="ICLR2022", prompts_mode=False, # outputs parameters
206
  ):
@@ -245,7 +245,7 @@ def generate_draft(title, description="", # main input
245
  "abstract"]
246
  else:
247
  sections = _filter_sections(sections)
248
- paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, bib_refs,
249
  max_tokens_ref=max_tokens_ref, max_tokens_kd=max_tokens_kd,
250
  query_counts=query_counts,
251
  knowledge_database=knowledge_database)
@@ -254,11 +254,10 @@ def generate_draft(title, description="", # main input
254
  prompts_dict = {}
255
  print(f"================PROCESSING================")
256
  for section in sections:
 
 
257
  if prompts_mode:
258
- prompts = generate_paper_prompts(paper, section)
259
- prompts_dict[section] = prompts
260
  continue
261
-
262
  print(f"Generate {section} part...")
263
  max_attempts = 4
264
  attempts_count = 0
@@ -274,21 +273,16 @@ def generate_draft(title, description="", # main input
274
  logging.info(message)
275
  attempts_count += 1
276
  time.sleep(15)
277
-
278
  # post-processing
279
  print("================POST-PROCESSING================")
280
  create_copies(destination_folder)
281
- input_dict = {"title": title, "description": description, "generator": "generate_draft"}
282
- filename = hash_name(input_dict) + ".zip"
 
283
  print("\nMission completed.\n")
 
284
 
285
- if prompts_mode:
286
- filename = hash_name(input_dict) + ".json"
287
- with open(filename, "w") as f:
288
- json.dump(prompts_dict, f)
289
- return filename
290
- else:
291
- return make_archive(destination_folder, filename)
292
 
293
 
294
  if __name__ == "__main__":
 
40
 
41
 
42
  def _generation_setup(title, description="", template="ICLR2022",
43
+ tldr=False, max_kw_refs=10, refs=None, max_tokens_ref=2048, # generating references
44
  knowledge_database=None, max_tokens_kd=2048, query_counts=10, # querying from knowledge database
45
  debug=True):
46
  """
 
115
 
116
  print("Keywords: \n", keywords)
117
  # todo: in some rare situations, collected papers will be an empty list. handle this issue
118
+ ref = References(title, load_papers=refs)
119
  ref.collect_papers(keywords, tldr=tldr)
120
  references = ref.to_prompts(max_tokens=max_tokens_ref)
121
  all_paper_ids = ref.to_bibtex(bibtex_path)
 
200
 
201
 
202
  def generate_draft(title, description="", # main input
203
+ tldr=True, max_kw_refs=10, refs=None, max_tokens_ref=2048, # references
204
  knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
205
  sections=None, model="gpt-4", template="ICLR2022", prompts_mode=False, # outputs parameters
206
  ):
 
245
  "abstract"]
246
  else:
247
  sections = _filter_sections(sections)
248
+ paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, refs,
249
  max_tokens_ref=max_tokens_ref, max_tokens_kd=max_tokens_kd,
250
  query_counts=query_counts,
251
  knowledge_database=knowledge_database)
 
254
  prompts_dict = {}
255
  print(f"================PROCESSING================")
256
  for section in sections:
257
+ prompts = generate_paper_prompts(paper, section)
258
+ prompts_dict[section] = prompts
259
  if prompts_mode:
 
 
260
  continue
 
261
  print(f"Generate {section} part...")
262
  max_attempts = 4
263
  attempts_count = 0
 
273
  logging.info(message)
274
  attempts_count += 1
275
  time.sleep(15)
 
276
  # post-processing
277
  print("================POST-PROCESSING================")
278
  create_copies(destination_folder)
279
+ filename = "prompts.json"
280
+ with open(os.path.join(destination_folder, filename), "w") as f:
281
+ json.dump(prompts_dict, f)
282
  print("\nMission completed.\n")
283
+ return destination_folder
284
 
285
+ # return make_archive(destination_folder, filename)
 
 
 
 
 
 
286
 
287
 
288
  if __name__ == "__main__":
configurations/default.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ date: 2023-07-11
2
+
3
+ generator: "auto_draft"
4
+
5
+ paper:
6
+ title: "playing atari game with deep reinforcement learning"
7
+ description: ""
8
+
9
+ references:
10
+ tldr: True
11
+ max_kw_refs: 10
12
+ max_tokens_ref: 2048
13
+ refs: null
14
+
15
+ domain_knowledge:
16
+ knowledge_database: null
17
+ max_tokens_kd: 2048
18
+ query_counts: 10
19
+
20
+ output:
21
+ template: "default"
22
+ model: "gpt-4"
23
+ selected_sections: null
24
+ prompts_mode: False
25
+
26
+
27
+
28
+
29
+
cyber-supervisor-openai.py CHANGED
@@ -3,7 +3,7 @@ import openai
3
  import ast
4
  from tools import functions, TOOLS
5
 
6
- MAX_ITER = 5
7
 
8
  openai.api_key = os.getenv("OPENAI_API_KEY")
9
  default_model = os.getenv("DEFAULT_MODEL")
 
3
  import ast
4
  from tools import functions, TOOLS
5
 
6
+ MAX_ITER = 99
7
 
8
  openai.api_key = os.getenv("OPENAI_API_KEY")
9
  default_model = os.getenv("DEFAULT_MODEL")
idealab.py DELETED
@@ -1,144 +0,0 @@
1
- import gradio as gr
2
- import os
3
- import openai
4
- from utils.references import References
5
- from utils.gpt_interaction import GPTModel
6
- from utils.prompts import SYSTEM
7
-
8
- openai_key = os.getenv("OPENAI_API_KEY")
9
- default_model = os.getenv("DEFAULT_MODEL")
10
- if default_model is None:
11
- # default_model = "gpt-3.5-turbo-16k"
12
- default_model = "gpt-4"
13
-
14
- openai.api_key = openai_key
15
-
16
- paper_system_prompt = '''You are an assistant designed to propose choices of research direction.
17
- The user will input questions or some keywords of a fields. You need to generate some paper titles and main contributions. Ensure follow the following instructions:
18
- Instruction:
19
- - Your response should follow the JSON format.
20
- - Your response should have the following structure:
21
- {
22
- "your suggested paper title":
23
- {
24
- "summary": "an overview introducing what this paper will include",
25
- "contributions": {
26
- "contribution1": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
27
- "contribution2": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
28
- ...
29
- }
30
- }
31
- "your suggested paper title":
32
- {
33
- "summary": "an overview introducing what this paper will include",
34
- "contributions": {
35
- "contribution1": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
36
- "contribution2": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
37
- ...
38
- }
39
- }
40
- ...
41
- }
42
- - Please list three to five suggested title and at least three contributions for each paper.
43
- '''
44
-
45
-
46
- contribution_system_prompt = '''You are an assistant designed to criticize the contributions of a paper. You will be provided Paper's Title, References and Contributions. Ensure follow the following instructions:
47
- Instruction:
48
- - Your response should follow the JSON format.
49
- - Your response should have the following structure:
50
- {
51
- "title": "the title provided by the user",
52
- "comment": "your thoughts on if this title clearly reflects the key ideas of this paper and explain why"
53
- "contributions": {
54
- "contribution1": {"statement": "briefly describe what the contribution is",
55
- "reason": "reason why the user claims it is a contribution",
56
- "judge": "your thought about if this is a novel contribution and explain why",
57
- "suggestion": "your suggestion on how to modify the research direction to enhance the novelty "},
58
- "contribution2": {"statement": "briefly describe what the contribution is",
59
- "reason": "reason why the user claims it is a contribution",
60
- "judge": "your thought about if this is a novel contribution and explain why",
61
- "suggestion": "your suggestion on how to modify the research direction to enhance the novelty "},
62
- ...
63
- }
64
- }
65
- - You need to carefully check if the claimed contribution has been made in the provided references, which makes the contribution not novel.
66
- - You also need to propose your concerns on if any of contributions could be incremental or just a mild modification on an existing work.
67
- '''
68
-
69
-
70
- ANNOUNCEMENT = """
71
- <h1 style="text-align: center"><img src='/file=assets/idealab.png' width=36px style="display: inline"/>灵感实验室IdeaLab</h1>
72
-
73
- <p>灵感实验室IdeaLab可以为你选择你下一篇论文的研究方向! 输入你的研究领域或者任何想法, 灵感实验室会自动生成若干个论文标题+论文的主要贡献供你选择. </p>
74
-
75
- <p>除此之外, 输入你的论文标题+主要贡献, 它会自动搜索相关文献, 来验证这个想法是不是有人做过了.</p>
76
- """
77
-
78
-
79
- def criticize_my_idea(title, contributions, max_tokens=4096):
80
- ref = References(title=title, description=f"{contributions}")
81
- keywords, _ = llm(systems=SYSTEM["keywords"], prompts=title, return_json=True)
82
- keywords = {keyword: 10 for keyword in keywords}
83
- ref.collect_papers(keywords)
84
- ref_prompt = ref.to_prompts(max_tokens=max_tokens)
85
-
86
- prompt = f"Title: {title}\n References: {ref_prompt}\n Contributions: {contributions}"
87
- output, _ = llm(systems=contribution_system_prompt, prompts=prompt, return_json=True)
88
- return output, ref_prompt
89
-
90
- def paste_title(suggestions):
91
- if suggestions:
92
- title = suggestions['title']['new title']
93
- contributions = suggestions['contributions']
94
-
95
- return title, contributions, {}, {}, {}
96
- else:
97
- return "", "", {}, {}, {}
98
-
99
- def generate_choices(thoughts):
100
- output, _ = llm(systems=paper_system_prompt, prompts=thoughts, return_json=True)
101
- return output
102
-
103
-
104
- # def translate_json(json_input):
105
- # system_prompt = "You are a translation bot. The user will input a JSON format string. You need to translate it into Chinese and return in the same formmat."
106
- # output, _ = llm(systems=system_prompt, prompts=str(json_input), return_json=True)
107
- # return output
108
-
109
-
110
- with gr.Blocks() as demo:
111
- llm = GPTModel(model=default_model)
112
-
113
- gr.HTML(ANNOUNCEMENT)
114
- with gr.Row():
115
- with gr.Tab("生成论文想法 (Generate Paper Ideas)"):
116
- thoughts_input = gr.Textbox(label="Thoughts")
117
- with gr.Accordion("Show prompts", open=False):
118
- prompts_1 = gr.Textbox(label="Prompts", interactive=False, value=paper_system_prompt)
119
-
120
- with gr.Row():
121
- button_generate_idea = gr.Button("Make it an idea!", variant="primary")
122
-
123
- with gr.Tab("验证想法可行性 (Validate Feasibility)"):
124
- title_input = gr.Textbox(label="Title")
125
- contribution_input = gr.Textbox(label="Contributions", lines=5)
126
- with gr.Accordion("Show prompts", open=False):
127
- prompts_2 = gr.Textbox(label="Prompts", interactive=False, value=contribution_system_prompt)
128
-
129
- with gr.Row():
130
- button_submit = gr.Button("Criticize my idea!", variant="primary")
131
-
132
- with gr.Tab("生成论文 (Generate Paper)"):
133
- gr.Markdown("...")
134
-
135
- with gr.Column(scale=1):
136
- contribution_output = gr.JSON(label="Contributions")
137
- # cn_output = gr.JSON(label="主要贡献")
138
- with gr.Accordion("References", open=False):
139
- references_output = gr.JSON(label="References")
140
-
141
- button_submit.click(fn=criticize_my_idea, inputs=[title_input, contribution_input], outputs=[contribution_output, references_output])
142
- button_generate_idea.click(fn=generate_choices, inputs=thoughts_input, outputs=contribution_output)#.success(translate_json, contribution_output, cn_output)
143
- demo.queue(concurrency_count=1, max_size=5, api_open=False)
144
- demo.launch(show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
kdb_test.py CHANGED
@@ -6,11 +6,15 @@ import gradio as gr
6
  import os
7
  import json
8
  from models import EMBEDDINGS
 
 
 
9
 
10
- # todo: 功能还没做
 
11
 
12
- HF_TOKEN = None # os.getenv("HF_TOKEN")
13
- REPO_ID = None # os.getenv("KDB_REPO")
14
  if HF_TOKEN is not None and REPO_ID is not None:
15
  snapshot_download(REPO_ID, repo_type="dataset", local_dir="knowledge_databases/",
16
  local_dir_use_symlinks=False, token=HF_TOKEN)
@@ -50,6 +54,29 @@ def query_from_kdb(input, kdb, query_counts):
50
  raise RuntimeError(f"Failed to query from FAISS.")
51
  return domain_knowledge, ""
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  with gr.Blocks() as demo:
54
  with gr.Row():
55
  with gr.Column():
@@ -76,9 +103,16 @@ with gr.Blocks() as demo:
76
  interactive=True, label="QUERY_COUNTS",
77
  info="How many contents will be retrieved from the vector database.")
78
 
79
- retrieval_output = gr.JSON(label="Output")
 
 
80
 
81
- button_retrieval.click(fn=query_from_kdb, inputs=[user_input, kdb_dropdown, query_counts_slider], outputs=[retrieval_output, user_input])
 
 
 
 
 
82
 
83
  demo.queue(concurrency_count=1, max_size=5, api_open=False)
84
  demo.launch(show_error=True)
 
6
  import os
7
  import json
8
  from models import EMBEDDINGS
9
+ from utils.gpt_interaction import GPTModel
10
+ from utils.prompts import SYSTEM
11
+ import openai
12
 
13
+ llm = GPTModel(model="gpt-3.5-turbo")
14
+ openai.api_key = os.getenv("OPENAI_API_KEY")
15
 
16
+ HF_TOKEN = os.getenv("HF_TOKEN")
17
+ REPO_ID = os.getenv("KDB_REPO")
18
  if HF_TOKEN is not None and REPO_ID is not None:
19
  snapshot_download(REPO_ID, repo_type="dataset", local_dir="knowledge_databases/",
20
  local_dir_use_symlinks=False, token=HF_TOKEN)
 
54
  raise RuntimeError(f"Failed to query from FAISS.")
55
  return domain_knowledge, ""
56
 
57
+ def query_from_kdb_llm(title, contributions, kdb, query_counts):
58
+ if kdb == "(None)":
59
+ return {"knowledge_database": "(None)", "title": title, "contributions": contributions, "output": ""}, "", {}
60
+
61
+ db_path = f"knowledge_databases/{kdb}"
62
+ db_config_path = os.path.join(db_path, "db_meta.json")
63
+ db_index_path = os.path.join(db_path, "faiss_index")
64
+ if os.path.isdir(db_path):
65
+ # load configuration file
66
+ with open(db_config_path, "r", encoding="utf-8") as f:
67
+ db_config = json.load(f)
68
+ model_name = db_config["embedding_model"]
69
+ embeddings = EMBEDDINGS[model_name]
70
+ db = FAISS.load_local(db_index_path, embeddings)
71
+ knowledge = Knowledge(db=db)
72
+ prompts = f"Title: {title}\n Contributions: {contributions}"
73
+ preliminaries_kw, _ = llm(systems=SYSTEM["preliminaries"], prompts=prompts, return_json=True)
74
+ knowledge.collect_knowledge(preliminaries_kw, max_query=query_counts)
75
+ domain_knowledge = knowledge.to_json()
76
+ else:
77
+ raise RuntimeError(f"Failed to query from FAISS.")
78
+ return domain_knowledge, "", preliminaries_kw
79
+
80
  with gr.Blocks() as demo:
81
  with gr.Row():
82
  with gr.Column():
 
103
  interactive=True, label="QUERY_COUNTS",
104
  info="How many contents will be retrieved from the vector database.")
105
 
106
+ with gr.Column():
107
+ retrieval_output = gr.JSON(label="Output")
108
+ llm_kws = gr.JSON(label="Keywords generated by LLM")
109
 
110
+ button_retrieval.click(fn=query_from_kdb,
111
+ inputs=[user_input, kdb_dropdown, query_counts_slider],
112
+ outputs=[retrieval_output, user_input])
113
+ button_retrieval_2.click(fn=query_from_kdb_llm,
114
+ inputs=[title_input, contribution_input, kdb_dropdown, query_counts_slider],
115
+ outputs=[retrieval_output, user_input, llm_kws])
116
 
117
  demo.queue(concurrency_count=1, max_size=5, api_open=False)
118
  demo.launch(show_error=True)
references_generator.py DELETED
@@ -1,86 +0,0 @@
1
- '''
2
- This script is used to generate the most relevant papers of a given title.
3
- - Search for as many as possible references. For 10~15 keywords, 10 references each.
4
- - Sort the results from most relevant to least relevant.
5
- - Return the most relevant using token size.
6
-
7
- Note: we do not use this function in auto-draft function. It has been integrated in that.
8
- '''
9
-
10
- import os.path
11
- import json
12
- from utils.references import References
13
- from section_generator import keywords_generation # section_generation_bg, #, figures_generation, section_generation
14
- import itertools
15
- from gradio_client import Client
16
-
17
-
18
- def generate_raw_references(title, description="",
19
- bib_refs=None, tldr=False, max_kw_refs=10,
20
- save_to="ref.bib"):
21
- # load pre-provided references
22
- ref = References(title, bib_refs)
23
-
24
- # generate multiple keywords for searching
25
- input_dict = {"title": title, "description": description}
26
- keywords, usage = keywords_generation(input_dict)
27
- keywords = list(keywords)
28
- comb_keywords = list(itertools.combinations(keywords, 2))
29
- for comb_keyword in comb_keywords:
30
- keywords.append(" ".join(comb_keyword))
31
- keywords = {keyword:max_kw_refs for keyword in keywords}
32
- print(f"keywords: {keywords}\n\n")
33
-
34
- ref.collect_papers(keywords, tldr=tldr)
35
- paper_json = ref.to_json()
36
-
37
- with open(save_to, "w") as f:
38
- json.dump(paper_json, f)
39
-
40
- return save_to, ref # paper_json
41
-
42
- def generate_top_k_references(title, description="",
43
- bib_refs=None, tldr=False, max_kw_refs=10, save_to="ref.bib", top_k=5):
44
- json_path, ref_raw = generate_raw_references(title, description, bib_refs, tldr, max_kw_refs, save_to)
45
- json_content = ref_raw.to_json()
46
-
47
- client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
48
- result = client.predict(
49
- title, # str in 'Title' Textbox component
50
- json_path, # str (filepath or URL to file) in 'Papers JSON (as string)' File component
51
- top_k, # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
52
- api_name="/get_k_relevant_papers"
53
- )
54
- with open(result) as f:
55
- result = json.load(f)
56
- return result
57
-
58
-
59
- if __name__ == "__main__":
60
- import openai
61
- openai.api_key = os.getenv("OPENAI_API_KEY")
62
-
63
- title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
64
- description = ""
65
- save_to = "paper.json"
66
- save_to, paper_json = generate_raw_references(title, description, save_to=save_to)
67
-
68
- print("`paper.json` has been generated. Now evaluating its similarity...")
69
-
70
- k = 5
71
- client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
72
- result = client.predict(
73
- title, # str in 'Title' Textbox component
74
- save_to, # str (filepath or URL to file) in 'Papers JSON (as string)' File component
75
- k, # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
76
- api_name="/get_k_relevant_papers"
77
- )
78
-
79
- with open(result) as f:
80
- result = json.load(f)
81
-
82
- print(result)
83
-
84
- save_to = "paper2.json"
85
- with open(save_to, "w") as f:
86
- json.dump(result, f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/knowledge.py CHANGED
@@ -16,7 +16,7 @@ class Knowledge:
16
  self.db = db
17
  self.contents = []
18
 
19
- def collect_knowledge(self, keywords_dict, max_query):
20
  """
21
  keywords_dict:
22
  {"machine learning": 5, "language model": 2};
 
16
  self.db = db
17
  self.contents = []
18
 
19
+ def collect_knowledge(self, keywords_dict: dict, max_query: int):
20
  """
21
  keywords_dict:
22
  {"machine learning": 5, "language model": 2};
utils/references.py CHANGED
@@ -3,52 +3,68 @@
3
  #
4
  # Generate references:
5
  # `Reference` class:
6
- # 1. Read a given .bib file to collect papers; use `search_paper_abstract` method to fill missing abstract.
 
 
7
  # 2. Given some keywords; use Semantic Scholar API to find papers.
8
  # 3. Generate bibtex from the selected papers. --> to_bibtex()
9
  # 4. Generate prompts from the selected papers: --> to_prompts()
10
  # A sample prompt: {"paper_id": "paper summary"}
 
11
 
12
- # todo: (1) citations & citedby of provided papers:
13
- # load the pre-defined papers; use S2 to find all related works
14
- # add all citations to `bib_papers`
15
- # add all citedby to `bib_papers`
16
- # use Semantic Scholar to find their embeddings
17
- # (2) separate references:
18
- # divide references into different groups to reduce the tokens count
19
- # for generating different paragraph of related works, use different set of references
20
- from typing import Dict, List
21
- import requests
22
  import re
 
 
 
 
23
  import bibtexparser
24
- import random
25
- from scholarly import scholarly
26
- from scholarly import ProxyGenerator
27
- import tiktoken
28
- import itertools, uuid, json
29
- from gradio_client import Client
30
- import time
31
  import numpy as np
 
 
32
  from numpy.linalg import norm
 
 
33
 
34
-
35
  URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"
36
  MAX_BATCH_SIZE = 16
37
  MAX_ATTEMPTS = 20
38
 
 
 
 
 
 
39
  ######################################################################################################################
40
  # Some basic tools
41
  ######################################################################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def evaluate_cosine_similarity(v1, v2):
43
  try:
44
- return np.dot(v1, v2)/(norm(v1)*norm(v2))
45
  except ValueError:
46
  return 0.0
47
 
 
48
  def chunks(lst, chunk_size=MAX_BATCH_SIZE):
49
  """Splits a longer list to respect batch size"""
50
  for i in range(0, len(lst), chunk_size):
51
- yield lst[i : i + chunk_size]
 
52
 
53
  def embed(papers):
54
  embeddings_by_paper_id: Dict[str, List[float]] = {}
@@ -64,6 +80,7 @@ def embed(papers):
64
 
65
  return embeddings_by_paper_id
66
 
 
67
  def get_embeddings(paper_title, paper_description):
68
  output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
69
  emb_vector = embed(output)["target_paper"]
@@ -71,9 +88,17 @@ def get_embeddings(paper_title, paper_description):
71
  target_paper["embeddings"] = emb_vector
72
  return target_paper
73
 
 
 
 
 
 
 
 
74
  def get_top_k(papers_dict, paper_title, paper_description, k=None):
 
75
  target_paper = get_embeddings(paper_title, paper_description)
76
- papers = papers_dict # must include embeddings
77
 
78
  # if k < len(papers_json), return k most relevant papers
79
  # if k >= len(papers_json) or k is None, return all papers
@@ -88,7 +113,7 @@ def get_top_k(papers_dict, paper_title, paper_description, k=None):
88
  for k in papers:
89
  v = papers[k]
90
  embedding_vector = v["embeddings"]
91
- cos_sim = evaluate_cosine_similarity(embedding_vector, target_embedding_vector)
92
  papers[k]["cos_sim"] = cos_sim
93
 
94
  # return the best k papers
@@ -97,14 +122,6 @@ def get_top_k(papers_dict, paper_title, paper_description, k=None):
97
  sorted_papers[key].pop("embeddings", None)
98
  return sorted_papers
99
 
100
- def remove_newlines(serie):
101
- # This function is applied to the abstract of each paper to reduce the length of prompts.
102
- serie = serie.replace('\n', ' ')
103
- serie = serie.replace('\\n', ' ')
104
- serie = serie.replace(' ', ' ')
105
- serie = serie.replace(' ', ' ')
106
- return serie
107
-
108
 
109
  def search_paper_abstract(title):
110
  pg = ProxyGenerator()
@@ -123,6 +140,159 @@ def search_paper_abstract(title):
123
  return remove_newlines(found_paper['bib']['abstract'])
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def load_papers_from_bibtex(bib_file_path):
127
  with open(bib_file_path) as bibtex_file:
128
  bib_database = bibtexparser.load(bibtex_file)
@@ -154,15 +324,20 @@ def load_papers_from_bibtex(bib_file_path):
154
  bib_papers.append(result)
155
  return bib_papers
156
 
157
- # `tokenizer`: used to count how many tokens
158
- tokenizer_name = tiktoken.encoding_for_model('gpt-4')
159
- tokenizer = tiktoken.get_encoding(tokenizer_name.name)
160
-
161
 
162
- def tiktoken_len(text):
163
- # evaluate how many tokens for the given text
164
- tokens = tokenizer.encode(text, disallowed_special=())
165
- return len(tokens)
 
 
 
 
 
 
 
 
 
166
 
167
 
168
  ######################################################################################################################
@@ -174,7 +349,7 @@ def ss_search(keywords, limit=20, fields=None):
174
  fields = ["title", "abstract", "venue", "year", "authors", "tldr", "embedding", "externalIds"]
175
  keywords = keywords.lower()
176
  keywords = keywords.replace(" ", "+")
177
- url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={keywords}&limit={limit}&fields={",".join(fields)}'
178
  # headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
179
  headers = {"Accept": "*/*"}
180
 
@@ -183,27 +358,6 @@ def ss_search(keywords, limit=20, fields=None):
183
 
184
 
185
  def _collect_papers_ss(keyword, counts=3, tldr=False):
186
- def externalIds2link(externalIds):
187
- # Sample externalIds:
188
- # "{'MAG': '2932819148', 'DBLP': 'conf/icml/HaarnojaZAL18', 'ArXiv': '1801.01290', 'CorpusId': 28202810}"
189
- if externalIds:
190
- # Supports ArXiv, MAG, ACL, PubMed, Medline, PubMedCentral, DBLP, DOI
191
- # priority: DBLP > arXiv > (todo: MAG > CorpusId > DOI > ACL > PubMed > Mdeline > PubMedCentral)
192
- # DBLP
193
- dblp_id = externalIds.get('DBLP')
194
- if dblp_id is not None:
195
- dblp_link = f"dblp.org/rec/{dblp_id}"
196
- return dblp_link
197
- # arXiv
198
- arxiv_id = externalIds.get('ArXiv')
199
- if arxiv_id is not None:
200
- arxiv_link = f"arxiv.org/abs/{arxiv_id}"
201
- return arxiv_link
202
- return ""
203
- else:
204
- # if this is an empty dictionary, return an empty string
205
- return ""
206
-
207
  def extract_paper_id(last_name, year_str, title):
208
  pattern = r'^\w+'
209
  words = re.findall(pattern, title)
@@ -289,24 +443,28 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
289
  ######################################################################################################################
290
 
291
  class References:
292
- def __init__(self, title, load_papers=None, keyword="customized_refs", description=""):
 
 
 
 
 
 
 
 
293
  if load_papers is not None:
294
- self.papers = {keyword: load_papers_from_bibtex(load_papers)}
295
- else:
296
- self.papers = {}
297
  self.title = title
298
  self.description = description
299
 
300
- def load_papers(self, bibtex, keyword):
301
- self.papers[keyword] = load_papers_from_bibtex(bibtex)
302
-
303
- def generate_keywords_dict(self):
304
  keywords_dict = {}
305
  for k in self.papers:
306
  keywords_dict[k] = len(self.papers[k])
307
  return keywords_dict
308
 
309
- def collect_papers(self, keywords_dict, tldr=False):
310
  """
311
  Collect as many papers as possible
312
 
@@ -320,21 +478,15 @@ class References:
320
  keywords.append(" ".join(comb_keyword))
321
  for key in keywords:
322
  self.papers[key] = _collect_papers_ss(key, 10, tldr)
323
- # print("Collected papers: ", papers)
324
- # for key, counts in keywords_dict.items():
325
- # self.papers[key] = _collect_papers_ss(key, counts, tldr)
326
 
327
- def to_bibtex(self, path_to_bibtex="ref.bib"):
328
  """
329
  Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
330
  """
331
- # todo:
332
- # use embeddings to evaluate; keep top k relevant references in papers
333
- # send (title, .bib file) to evaluate embeddings; recieve truncated papers
334
  papers = self._get_papers(keyword="_all")
335
 
336
- l = len(papers)
337
- print(f"{l} papers will be added to `ref.bib`.")
338
  # clear the bibtex file
339
  with open(path_to_bibtex, "w", encoding="utf-8") as file:
340
  file.write("")
@@ -372,7 +524,7 @@ class References:
372
  papers = self.papers["keyword"]
373
  return papers
374
 
375
- def to_prompts(self, keyword="_all", max_tokens=2048):
376
  # `prompts`:
377
  # {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
378
  # this will be used to instruct GPT model to cite the correct bibtex entry.
@@ -384,21 +536,11 @@ class References:
384
  papers_json = self.to_json()
385
  with open(json_path, "w") as f:
386
  json.dump(papers_json, f)
387
-
388
  try:
389
  # Use external API to obtain the most relevant papers
390
  title = self.title
391
  description = self.description
392
  result = get_top_k(papers_json, title, description)
393
- # client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
394
- # result = client.predict(
395
- # title, # str in 'Title' Textbox component
396
- # json_path, # str (filepath or URL to file) in 'Papers JSON (as string)' File component
397
- # 50, # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
398
- # api_name="/get_k_relevant_papers"
399
- # )
400
- # with open(result) as f:
401
- # result = json.load(f)
402
  result = [item for key, item in result.items()]
403
  except Exception as e:
404
  print(f"Error occurs during calling external API: {e}\n")
@@ -417,54 +559,9 @@ class References:
417
  break
418
  return prompts
419
 
420
- def to_json(self, keyword="_all"):
421
  papers = self._get_papers(keyword)
422
  papers_json = {}
423
  for paper in papers:
424
  papers_json[paper["paper_id"]] = paper
425
  return papers_json
426
-
427
-
428
- if __name__ == "__main__":
429
- # testing search results
430
- print("================Testing `ss_search`================")
431
- r = ss_search("Deep Q-Networks", limit=1) # a list of raw papers
432
- if r['total'] > 0:
433
- paper = r['data'][0]
434
- # print(paper)
435
-
436
- # resting References
437
- print("================Testing `References`================")
438
- refs = References(title="Super Deep Q-Networks")
439
- keywords_dict = {
440
- "Deep Q-Networks": 5,
441
- "Actor-Critic Algorithms": 4,
442
- "Exploration-Exploitation Trade-off": 3
443
- }
444
- print("================Testing `References.collect_papers`================")
445
- refs.collect_papers(keywords_dict, tldr=True)
446
- for k in refs.papers:
447
- papers = refs.papers[k] # for each keyword, there is a list of papers
448
- print("keyword: ", k)
449
- for paper in papers:
450
- print(paper["paper_id"])
451
-
452
- print("================Testing `References.to_bibtex`================")
453
- refs.to_bibtex()
454
-
455
- print("================Testing `References.to_json`================")
456
- papers_json = refs.to_json() # this json can be used to find the most relevant papers
457
- with open("papers.json", "w", encoding='utf-8') as text_file:
458
- text_file.write(f"{papers_json}")
459
-
460
- print("================Testing `References.to_prompts`================")
461
- prompts = refs.to_prompts()
462
- print(prompts)
463
-
464
- # bib = "test.bib"
465
- # refs.load_papers(bib, "variance-reduction rl")
466
- # print(refs.papers)
467
- #
468
- # prompts = refs.to_prompts()
469
- # for k in prompts:
470
- # print(f"{k}: {prompts[k]}\n")
 
3
  #
4
  # Generate references:
5
  # `Reference` class:
6
+ # 1. Two methods to load papers:
7
+ # 1.1. Read a given string including paper titles separated by `,`
8
+ # 1.2. Read a .bib file
9
  # 2. Given some keywords; use Semantic Scholar API to find papers.
10
  # 3. Generate bibtex from the selected papers. --> to_bibtex()
11
  # 4. Generate prompts from the selected papers: --> to_prompts()
12
  # A sample prompt: {"paper_id": "paper summary"}
13
+ # 5. Generate json from the selected papers. --> to_json()
14
 
15
+ import itertools
16
+ import json
 
 
 
 
 
 
 
 
17
  import re
18
+ import uuid
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import arxiv
22
  import bibtexparser
 
 
 
 
 
 
 
23
  import numpy as np
24
+ import requests
25
+ import tiktoken
26
  from numpy.linalg import norm
27
+ from scholarly import ProxyGenerator
28
+ from scholarly import scholarly
29
 
30
+ # used to evaluate embeddings
31
  URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"
32
  MAX_BATCH_SIZE = 16
33
  MAX_ATTEMPTS = 20
34
 
35
+ # `tokenizer`: used to count how many tokens
36
+ tokenizer_name = tiktoken.encoding_for_model('gpt-4')
37
+ tokenizer = tiktoken.get_encoding(tokenizer_name.name)
38
+
39
+
40
  ######################################################################################################################
41
  # Some basic tools
42
  ######################################################################################################################
43
+ def remove_special_characters(s):
44
+ return ''.join(c for c in s if c.isalnum() or c.isspace() or c == ',')
45
+
46
+
47
+ def remove_newlines(serie):
48
+ # This function is applied to the abstract of each paper to reduce the length of prompts.
49
+ serie = serie.replace('\n', ' ')
50
+ serie = serie.replace('\\n', ' ')
51
+ serie = serie.replace(' ', ' ')
52
+ serie = serie.replace(' ', ' ')
53
+ return serie
54
+
55
+
56
  def evaluate_cosine_similarity(v1, v2):
57
  try:
58
+ return np.dot(v1, v2) / (norm(v1) * norm(v2))
59
  except ValueError:
60
  return 0.0
61
 
62
+
63
  def chunks(lst, chunk_size=MAX_BATCH_SIZE):
64
  """Splits a longer list to respect batch size"""
65
  for i in range(0, len(lst), chunk_size):
66
+ yield lst[i: i + chunk_size]
67
+
68
 
69
  def embed(papers):
70
  embeddings_by_paper_id: Dict[str, List[float]] = {}
 
80
 
81
  return embeddings_by_paper_id
82
 
83
+
84
  def get_embeddings(paper_title, paper_description):
85
  output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
86
  emb_vector = embed(output)["target_paper"]
 
88
  target_paper["embeddings"] = emb_vector
89
  return target_paper
90
 
91
+
92
+ def get_embeddings_vector(paper_title, paper_description):
93
+ output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
94
+ emb_vector = embed(output)["target_paper"]
95
+ return emb_vector
96
+
97
+
98
  def get_top_k(papers_dict, paper_title, paper_description, k=None):
99
+ # returns the top k papers most similar to the target paper
100
  target_paper = get_embeddings(paper_title, paper_description)
101
+ papers = papers_dict # must include embeddings
102
 
103
  # if k < len(papers_json), return k most relevant papers
104
  # if k >= len(papers_json) or k is None, return all papers
 
113
  for k in papers:
114
  v = papers[k]
115
  embedding_vector = v["embeddings"]
116
+ cos_sim = evaluate_cosine_similarity(embedding_vector, target_embedding_vector)
117
  papers[k]["cos_sim"] = cos_sim
118
 
119
  # return the best k papers
 
122
  sorted_papers[key].pop("embeddings", None)
123
  return sorted_papers
124
 
 
 
 
 
 
 
 
 
125
 
126
  def search_paper_abstract(title):
127
  pg = ProxyGenerator()
 
140
  return remove_newlines(found_paper['bib']['abstract'])
141
 
142
 
143
+ def tiktoken_len(text):
144
+ # evaluate how many tokens for the given text
145
+ tokens = tokenizer.encode(text, disallowed_special=())
146
+ return len(tokens)
147
+
148
+
149
+ ######################################################################################################################
150
+ # Academic search tools
151
+ ######################################################################################################################
152
+ def externalIds2link(externalIds):
153
+ # Sample externalIds:
154
+ # "{'MAG': '2932819148', 'DBLP': 'conf/icml/HaarnojaZAL18', 'ArXiv': '1801.01290', 'CorpusId': 28202810}"
155
+ if externalIds:
156
+ # Supports ArXiv, MAG, ACL, PubMed, Medline, PubMedCentral, DBLP, DOI
157
+ # priority: DBLP > arXiv > (todo: MAG > CorpusId > DOI > ACL > PubMed > Mdeline > PubMedCentral)
158
+ # DBLP
159
+ dblp_id = externalIds.get('DBLP')
160
+ if dblp_id is not None:
161
+ dblp_link = f"dblp.org/rec/{dblp_id}"
162
+ return dblp_link
163
+ # arXiv
164
+ arxiv_id = externalIds.get('ArXiv')
165
+ if arxiv_id is not None:
166
+ arxiv_link = f"arxiv.org/abs/{arxiv_id}"
167
+ return arxiv_link
168
+ return ""
169
+ else:
170
+ # if this is an empty dictionary, return an empty string
171
+ return ""
172
+
173
+
174
+ def search_paper_arxiv(title):
175
+ search = arxiv.Search(
176
+ query=title,
177
+ max_results=1,
178
+ sort_by=arxiv.SortCriterion.Relevance
179
+ )
180
+ try:
181
+ # (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal (8) embeddings
182
+ result = next(search.results())
183
+ title = result.title
184
+ authors = " and ".join([author.name for author in result.authors])
185
+ year = str(result.updated.now().year)
186
+ link = result.pdf_url
187
+ abstract = result.summary
188
+ journal = f"Arxiv: {result.entry_id}"
189
+ paper_id = result.authors[0].name.replace(" ", "")[:4] + year + title[:6].replace(" ", "")
190
+ paper_id = paper_id.lower()
191
+
192
+ paper = {"paper_id": paper_id,
193
+ "title": title,
194
+ "authors": authors,
195
+ "year": year,
196
+ "link": link,
197
+ "abstract": abstract,
198
+ "journal": journal}
199
+ except StopIteration:
200
+ paper = {}
201
+ return paper
202
+
203
+
204
+ def search_paper_ss(title):
205
+ fields = ["title", "abstract", "venue", "year", "authors", "tldr", "externalIds"]
206
+ limit = 1
207
+ url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={title}&limit={limit}&fields={",".join(fields)}'
208
+ # headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
209
+ headers = {"Accept": "*/*"}
210
+ response = requests.get(url, headers=headers, timeout=30)
211
+ results = response.json()
212
+ if results['total'] == 0:
213
+ return {}
214
+ raw_paper = results['data'][0]
215
+ if raw_paper['tldr'] is not None:
216
+ abstract = raw_paper['tldr']['text']
217
+ elif raw_paper['abstract'] is not None:
218
+ abstract = remove_newlines(raw_paper['abstract'])
219
+ else:
220
+ abstract = ""
221
+
222
+ authors = [author['name'] for author in raw_paper['authors']]
223
+ authors_str = " and ".join(authors)
224
+ year_str = str(raw_paper['year'])
225
+ title = raw_paper['title']
226
+
227
+ paper_id = authors_str.replace(" ", "")[:4] + year_str + title[:6].replace(" ", "")
228
+
229
+ # some journal may contain &; replace it. e.g. journal={IEEE Power & Energy Society General Meeting}
230
+ journal = remove_special_characters(raw_paper['venue'])
231
+ if not journal:
232
+ journal = "arXiv preprint"
233
+ link = externalIds2link(raw_paper['externalIds'])
234
+ paper = {
235
+ "paper_id": paper_id,
236
+ "title": title,
237
+ "abstract": abstract,
238
+ "link": link,
239
+ "authors": authors_str,
240
+ "year": year_str,
241
+ "journal": journal
242
+ }
243
+ return paper
244
+
245
+
246
+ def search_paper_scrape(title):
247
+ pg = ProxyGenerator()
248
+ success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
249
+ if success:
250
+ try:
251
+ scholarly.use_proxy(pg)
252
+ # input the title of a paper, return its abstract
253
+ search_query = scholarly.search_pubs(title)
254
+ found_paper = next(search_query)
255
+ url = found_paper['pub_url']
256
+
257
+ result = found_paper['bib']
258
+
259
+ title = result['title']
260
+ authors = " and ".join(result['author'])
261
+ year = str(result['pub_year'])
262
+ journal = result['pub_year']
263
+ abstract = result['abstract']
264
+
265
+ paper_id = authors.replace(" ", "")[:4] + year + title[:6].replace(" ", "")
266
+ paper = {
267
+ "paper_id": paper_id,
268
+ "title": title,
269
+ "abstract": abstract,
270
+ "link": url,
271
+ "authors": authors,
272
+ "year": year,
273
+ "journal": journal
274
+ }
275
+ return paper
276
+ except StopIteration:
277
+ return {}
278
+
279
+
280
+ def search_paper(title, verbose=True):
281
+ if verbose:
282
+ print(f"Searching {title}...")
283
+ # try Semantic Scholar first
284
+ paper = search_paper_ss(title)
285
+ if not paper:
286
+ paper = search_paper_arxiv(title)
287
+ if not paper:
288
+ paper = search_paper_scrape(title)
289
+ if paper:
290
+ paper["embeddings"] = get_embeddings_vector(paper_title=paper['title'], paper_description=paper['abstract'])
291
+ if verbose:
292
+ print(f"Search result: {paper}.")
293
+ return paper
294
+
295
+
296
  def load_papers_from_bibtex(bib_file_path):
297
  with open(bib_file_path) as bibtex_file:
298
  bib_database = bibtexparser.load(bibtex_file)
 
324
  bib_papers.append(result)
325
  return bib_papers
326
 
 
 
 
 
327
 
328
+ def load_papers_from_text(text):
329
+ # split text by comma
330
+ titles = [part.strip() for part in text.split(',')]
331
+ titles = [remove_special_characters(title) for title in titles]
332
+ papers = []
333
+ if len(titles) > 0:
334
+ for title in titles:
335
+ paper = search_paper(title)
336
+ if paper:
337
+ papers.append(paper)
338
+ return papers
339
+ else:
340
+ return []
341
 
342
 
343
  ######################################################################################################################
 
349
  fields = ["title", "abstract", "venue", "year", "authors", "tldr", "embedding", "externalIds"]
350
  keywords = keywords.lower()
351
  keywords = keywords.replace(" ", "+")
352
+ url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={keywords}&limit={limit}&fields={",".join(fields)} '
353
  # headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
354
  headers = {"Accept": "*/*"}
355
 
 
358
 
359
 
360
  def _collect_papers_ss(keyword, counts=3, tldr=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  def extract_paper_id(last_name, year_str, title):
362
  pattern = r'^\w+'
363
  words = re.findall(pattern, title)
 
443
  ######################################################################################################################
444
 
445
  class References:
446
+ def __init__(self,
447
+ title: str,
448
+ load_papers: Optional[str] = None,
449
+ load_bibtex: Optional[str] = None,
450
+ description: str = ""
451
+ ):
452
+ self.papers = {}
453
+ if load_bibtex is not None:
454
+ self.papers["load_from_bibtex"] = load_papers_from_bibtex(load_bibtex)
455
  if load_papers is not None:
456
+ self.papers["load_from_text"] = load_papers_from_text(load_papers)
457
+
 
458
  self.title = title
459
  self.description = description
460
 
461
+ def generate_keywords_dict(self) -> Dict[str, int]:
 
 
 
462
  keywords_dict = {}
463
  for k in self.papers:
464
  keywords_dict[k] = len(self.papers[k])
465
  return keywords_dict
466
 
467
+ def collect_papers(self, keywords_dict: Dict[str, int], tldr: bool = False) -> None:
468
  """
469
  Collect as many papers as possible
470
 
 
478
  keywords.append(" ".join(comb_keyword))
479
  for key in keywords:
480
  self.papers[key] = _collect_papers_ss(key, 10, tldr)
 
 
 
481
 
482
+ def to_bibtex(self, path_to_bibtex: str = "ref.bib") -> List[str]:
483
  """
484
  Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
485
  """
 
 
 
486
  papers = self._get_papers(keyword="_all")
487
 
488
+ num_papers = len(papers)
489
+ print(f"{num_papers} papers will be added to `ref.bib`.")
490
  # clear the bibtex file
491
  with open(path_to_bibtex, "w", encoding="utf-8") as file:
492
  file.write("")
 
524
  papers = self.papers["keyword"]
525
  return papers
526
 
527
+ def to_prompts(self, keyword: str = "_all", max_tokens: int = 2048):
528
  # `prompts`:
529
  # {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
530
  # this will be used to instruct GPT model to cite the correct bibtex entry.
 
536
  papers_json = self.to_json()
537
  with open(json_path, "w") as f:
538
  json.dump(papers_json, f)
 
539
  try:
540
  # Use external API to obtain the most relevant papers
541
  title = self.title
542
  description = self.description
543
  result = get_top_k(papers_json, title, description)
 
 
 
 
 
 
 
 
 
544
  result = [item for key, item in result.items()]
545
  except Exception as e:
546
  print(f"Error occurs during calling external API: {e}\n")
 
559
  break
560
  return prompts
561
 
562
+ def to_json(self, keyword: str = "_all"):
563
  papers = self._get_papers(keyword)
564
  papers_json = {}
565
  for paper in papers:
566
  papers_json[paper["paper_id"]] = paper
567
  return papers_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
worker.py CHANGED
@@ -3,7 +3,7 @@ This script is only used for service-side host.
3
  '''
4
  import boto3
5
  import os, time
6
- from api_wrapper import generator_wrapper
7
  from sqlalchemy import create_engine, Table, MetaData, update, select
8
  from sqlalchemy.orm import sessionmaker
9
  from sqlalchemy import inspect
 
3
  '''
4
  import boto3
5
  import os, time
6
+ from wrapper import generator_wrapper
7
  from sqlalchemy import create_engine, Table, MetaData, update, select
8
  from sqlalchemy.orm import sessionmaker
9
  from sqlalchemy import inspect
wrapper.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script is used to wrap all generation methods together.
3
+
4
+ todo:
5
+ A worker keeps running on the server. Monitor the Amazon SQS. Once receive a new message, do the following:
6
+ Download the corresponding configuration files on S3.
7
+ Change Task status from Pending to Running.
8
+ Call `generator_wrapper` and wait for the outputs.
9
+ If `generator_wrapper` returns results:
10
+ evaluate the results; compile it; upload results to S3 ... Change Task status from Running to Completed.
11
+ If anything goes wrong, raise Error.
12
+ If `generator_wrapper` returns nothing or Timeout, or raise any error:
13
+ Change Task status from Running to Failed.
14
+ """
15
+ from auto_generators import generate_draft
16
+ from utils.file_operations import make_archive
17
+ import yaml
18
+ import uuid
19
+
20
+
21
+ def remove_special_characters(s):
22
+ return ''.join(c for c in s if c.isalnum() or c.isspace() or c == ',')
23
+
24
+
25
+ def generator_wrapper(config):
26
+ if not isinstance(config, dict):
27
+ with open(config, "r") as file:
28
+ config = yaml.safe_load(file)
29
+ title = config["paper"]["title"]
30
+ generator = config["generator"]
31
+ if generator == "auto_draft":
32
+ folder = generate_draft(title, config["paper"]["description"],
33
+ tldr=config["references"]["tldr"],
34
+ max_kw_refs=config["references"]["max_kw_refs"],
35
+ refs=config["references"]["refs"],
36
+ max_tokens_ref=config["references"]["max_tokens_ref"],
37
+ knowledge_database=config["domain_knowledge"]["knowledge_database"],
38
+ max_tokens_kd=config["domain_knowledge"]["max_tokens_kd"],
39
+ query_counts=config["domain_knowledge"]["query_counts"],
40
+ sections=config["output"]["selected_sections"],
41
+ model=config["output"]["model"],
42
+ template=config["output"]["template"],
43
+ prompts_mode=config["output"]["prompts_mode"],
44
+ )
45
+ else:
46
+ raise NotImplementedError(f"The generator {generator} has not been supported yet.")
47
+ # todo: post processing: translate to Chinese, compile PDF ...
48
+ filename = remove_special_characters(title).replace(" ", "_") + uuid.uuid1().hex + ".zip"
49
+ return make_archive(folder, filename)
50
+
51
+
52
+ if __name__ == "__main__":
53
+ pass
54
+ # with open("configurations/default.yaml", 'r') as file:
55
+ # config = yaml.safe_load(file)
56
+ # print(config)
57
+ # generator_wrapper(config)