Spaces:
Running
Running
shaocongma
commited on
Commit
·
94dc00e
1
Parent(s):
328e8d0
Add a generator wrapper using configuration file. Edit the logic of searching references. Add Gradio UI for testing Knowledge database.
Browse files- api_wrapper.py +0 -42
- app.py +54 -36
- assets/idealab.png +0 -0
- auto_backgrounds.py → auto_generators.py +11 -17
- configurations/default.yaml +29 -0
- cyber-supervisor-openai.py +1 -1
- idealab.py +0 -144
- kdb_test.py +39 -5
- references_generator.py +0 -86
- utils/knowledge.py +1 -1
- utils/references.py +233 -136
- worker.py +1 -1
- wrapper.py +57 -0
api_wrapper.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
'''
|
2 |
-
This script is used to wrap all generation methods together.
|
3 |
-
|
4 |
-
todo:
|
5 |
-
A worker keeps running on the server. Monitor the Amazon SQS. Once receive a new message, do the following:
|
6 |
-
Download the corresponding configuration files on S3.
|
7 |
-
Change Task status from Pending to Running.
|
8 |
-
Call `generator_wrapper` and wait for the outputs.
|
9 |
-
If `generator_wrapper` returns results:
|
10 |
-
evaluate the results; compile it; upload results to S3 ... Change Task status from Running to Completed.
|
11 |
-
If anything goes wrong, raise Error.
|
12 |
-
If `generator_wrapper` returns nothing or Timeout, or raise any error:
|
13 |
-
Change Task status from Running to Failed.
|
14 |
-
'''
|
15 |
-
import os.path
|
16 |
-
|
17 |
-
from auto_backgrounds import generate_draft
|
18 |
-
import json, time
|
19 |
-
from utils.file_operations import make_archive
|
20 |
-
|
21 |
-
|
22 |
-
GENERATOR_MAPPING = {"fake": None, # a fake generator
|
23 |
-
"draft": generate_draft # generate academic paper
|
24 |
-
}
|
25 |
-
|
26 |
-
def generator_wrapper(config):
|
27 |
-
generator = GENERATOR_MAPPING[config["generator"]]
|
28 |
-
|
29 |
-
|
30 |
-
def generator_wrapper_from_json(path_to_config_json):
|
31 |
-
# Read configuration file and call corresponding function
|
32 |
-
with open(path_to_config_json, "r", encoding='utf-8') as f:
|
33 |
-
config = json.load(f)
|
34 |
-
print("Configuration:", config)
|
35 |
-
# generator = GENERATOR_MAPPING.get(config["generator"])
|
36 |
-
generator = None
|
37 |
-
if generator is None:
|
38 |
-
# generate a fake ZIP file and upload
|
39 |
-
time.sleep(150)
|
40 |
-
zip_path = os.path.splitext(path_to_config_json)[0]+".zip"
|
41 |
-
return make_archive(path_to_config_json, zip_path)
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -2,9 +2,10 @@ import uuid
|
|
2 |
import gradio as gr
|
3 |
import os
|
4 |
import openai
|
5 |
-
|
6 |
from utils.file_operations import list_folders, urlify
|
7 |
from huggingface_hub import snapshot_download
|
|
|
8 |
|
9 |
# todo:
|
10 |
# 6. get logs when the procedure is not completed. *
|
@@ -22,8 +23,10 @@ from huggingface_hub import snapshot_download
|
|
22 |
# OPENAI_API_BASE: (Optional) Support alternative OpenAI minors
|
23 |
# GPT4_ENABLE: (Optional) Set it to 1 to enable GPT-4 model.
|
24 |
|
25 |
-
# AWS_ACCESS_KEY_ID: (Optional)
|
26 |
-
#
|
|
|
|
|
27 |
# KDB_REPO: (Optional) A Huggingface dataset hosting Knowledge Databases
|
28 |
# HF_TOKEN: (Optional) Access to KDB_REPO
|
29 |
|
@@ -34,7 +37,7 @@ openai_key = os.getenv("OPENAI_API_KEY")
|
|
34 |
openai_api_base = os.getenv("OPENAI_API_BASE")
|
35 |
if openai_api_base is not None:
|
36 |
openai.api_base = openai_api_base
|
37 |
-
GPT4_ENABLE = os.getenv("GPT4_ENABLE")
|
38 |
|
39 |
access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
|
40 |
secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
|
@@ -124,7 +127,7 @@ REFERENCES = """## 一键搜索相关论文
|
|
124 |
REFERENCES_INSTRUCTION = """### References
|
125 |
这一栏用于定义AI如何选取参考文献. 目前是两种方式混合:
|
126 |
1. GPT自动根据标题生成关键字,使用Semantic Scholar搜索引擎搜索文献,利用Specter获取Paper Embedding来自动选取最相关的文献作为GPT的参考资料.
|
127 |
-
2.
|
128 |
关于有希望利用本地文件来供GPT参考的功能将在未来实装.
|
129 |
"""
|
130 |
|
@@ -140,7 +143,7 @@ OUTPUTS_INSTRUCTION = """### Outputs
|
|
140 |
这一栏用于定义输出的内容:
|
141 |
* Template: 用于填装内容的LaTeX模板.
|
142 |
* Models: 使用GPT-4或者GPT-3.5-Turbo生成内容.
|
143 |
-
* Prompts模式: 不生成内容, 而是生成用于生成内容的Prompts. 可以手动复制到网页版或者其他语言模型中进行使用.
|
144 |
"""
|
145 |
|
146 |
OTHERS_INSTRUCTION = """### Others
|
@@ -164,18 +167,34 @@ def clear_inputs(*args):
|
|
164 |
def clear_inputs_refs(*args):
|
165 |
return "", 5
|
166 |
|
|
|
167 |
def wrapped_generator(
|
168 |
paper_title, paper_description, # main input
|
169 |
-
openai_api_key=None,
|
170 |
-
tldr=True, max_kw_refs=10,
|
171 |
knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
|
172 |
paper_template="ICLR2022", selected_sections=None, model="gpt-4", prompts_mode=False, # outputs parameters
|
173 |
cache_mode=IS_CACHE_AVAILABLE # handle cache mode
|
174 |
):
|
175 |
-
# if `cache_mode` is True, then always upload the generated content to my S3.
|
176 |
file_name_upload = urlify(paper_title) + "_" + uuid.uuid1().hex + ".zip"
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
if openai_api_key is not None:
|
180 |
openai.api_key = openai_api_key
|
181 |
try:
|
@@ -183,12 +202,7 @@ def wrapped_generator(
|
|
183 |
except Exception as e:
|
184 |
raise gr.Error(f"Key错误. Error: {e}")
|
185 |
try:
|
186 |
-
output =
|
187 |
-
paper_title, description=paper_description, # main input
|
188 |
-
tldr=tldr, max_kw_refs=max_kw_refs, bib_refs=bib_refs, max_tokens_ref=max_tokens_ref, # references
|
189 |
-
knowledge_database=knowledge_database, max_tokens_kd=max_tokens_kd, query_counts=query_counts, # domain knowledge
|
190 |
-
sections=selected_sections, model=model, template=paper_template, prompts_mode=prompts_mode, # outputs parameters
|
191 |
-
)
|
192 |
if cache_mode:
|
193 |
from utils.storage import upload_file
|
194 |
upload_file(output, target_name=file_name_upload)
|
@@ -204,8 +218,6 @@ with gr.Blocks(theme=theme) as demo:
|
|
204 |
with gr.Column(scale=2):
|
205 |
key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key",
|
206 |
visible=not IS_OPENAI_API_KEY_AVAILABLE)
|
207 |
-
url = gr.Textbox(value=None, lines=1, max_lines=1, label="URL",
|
208 |
-
visible=False)
|
209 |
# 每个功能做一个tab
|
210 |
with gr.Tab("学术论文"):
|
211 |
gr.Markdown(ACADEMIC_PAPER)
|
@@ -230,8 +242,8 @@ with gr.Blocks(theme=theme) as demo:
|
|
230 |
interactive=GPT4_INTERACTIVE,
|
231 |
info="生成论文用到的语言模型.")
|
232 |
prompts_mode = gr.Checkbox(value=False, visible=True, interactive=True,
|
233 |
-
|
234 |
-
|
235 |
|
236 |
sections = gr.CheckboxGroup(
|
237 |
choices=["introduction", "related works", "backgrounds", "methodology", "experiments",
|
@@ -245,21 +257,27 @@ with gr.Blocks(theme=theme) as demo:
|
|
245 |
|
246 |
with gr.Column(scale=2):
|
247 |
max_kw_ref_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
|
248 |
-
|
249 |
-
|
250 |
|
251 |
max_tokens_ref_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
|
252 |
-
|
253 |
-
|
254 |
|
255 |
tldr_checkbox = gr.Checkbox(value=True, label="TLDR;",
|
256 |
info="选择此筐表示将使用Semantic Scholar的TLDR作为文献的总结.",
|
257 |
interactive=True)
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
with gr.Row():
|
265 |
with gr.Column(scale=1):
|
@@ -267,11 +285,11 @@ with gr.Blocks(theme=theme) as demo:
|
|
267 |
|
268 |
with gr.Column(scale=2):
|
269 |
query_counts_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
|
270 |
-
|
271 |
-
|
272 |
max_tokens_kd_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
|
273 |
-
|
274 |
-
|
275 |
domain_knowledge = gr.Dropdown(label="预载知识库",
|
276 |
choices=ALL_DATABASES,
|
277 |
value="(None)",
|
@@ -296,8 +314,8 @@ with gr.Blocks(theme=theme) as demo:
|
|
296 |
json_output = gr.JSON(label="References")
|
297 |
clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
|
298 |
submit_button_pp.click(fn=wrapped_generator,
|
299 |
-
inputs=[title, description_pp, key,
|
300 |
-
tldr_checkbox, max_kw_ref_slider,
|
301 |
domain_knowledge, max_tokens_kd_slider, query_counts_slider,
|
302 |
template, sections, model_selection, prompts_mode], outputs=file_output)
|
303 |
|
|
|
2 |
import gradio as gr
|
3 |
import os
|
4 |
import openai
|
5 |
+
import yaml
|
6 |
from utils.file_operations import list_folders, urlify
|
7 |
from huggingface_hub import snapshot_download
|
8 |
+
from wrapper import generator_wrapper
|
9 |
|
10 |
# todo:
|
11 |
# 6. get logs when the procedure is not completed. *
|
|
|
23 |
# OPENAI_API_BASE: (Optional) Support alternative OpenAI minors
|
24 |
# GPT4_ENABLE: (Optional) Set it to 1 to enable GPT-4 model.
|
25 |
|
26 |
+
# AWS_ACCESS_KEY_ID: (Optional)
|
27 |
+
# Access AWS cloud storage (you need to edit `BUCKET_NAME` in `utils/storage.py` if you need to use this function)
|
28 |
+
# AWS_SECRET_ACCESS_KEY: (Optional)
|
29 |
+
# Access AWS cloud storage (you need to edit `BUCKET_NAME` in `utils/storage.py` if you need to use this function)
|
30 |
# KDB_REPO: (Optional) A Huggingface dataset hosting Knowledge Databases
|
31 |
# HF_TOKEN: (Optional) Access to KDB_REPO
|
32 |
|
|
|
37 |
openai_api_base = os.getenv("OPENAI_API_BASE")
|
38 |
if openai_api_base is not None:
|
39 |
openai.api_base = openai_api_base
|
40 |
+
GPT4_ENABLE = os.getenv("GPT4_ENABLE") # disable GPT-4 for public repo
|
41 |
|
42 |
access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
|
43 |
secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
|
|
|
127 |
REFERENCES_INSTRUCTION = """### References
|
128 |
这一栏用于定义AI如何选取参考文献. 目前是两种方式混合:
|
129 |
1. GPT自动根据标题生成关键字,使用Semantic Scholar搜索引擎搜索文献,利用Specter获取Paper Embedding来自动选取最相关的文献作为GPT的参考资料.
|
130 |
+
2. 用户通过输入文章标题(用英文逗号隔开), AI会自动搜索文献作为参考资料.
|
131 |
关于有希望利用本地文件来供GPT参考的功能将在未来实装.
|
132 |
"""
|
133 |
|
|
|
143 |
这一栏用于定义输出的内容:
|
144 |
* Template: 用于填装内容的LaTeX模板.
|
145 |
* Models: 使用GPT-4或者GPT-3.5-Turbo生成内容.
|
146 |
+
* Prompts模式: 不生成内容, 而是生成用于生成内容的Prompts. 可以手动复制到网页版或者其他语言模型中进行使用. (放在输出的ZIP文件的prompts.json文件中)
|
147 |
"""
|
148 |
|
149 |
OTHERS_INSTRUCTION = """### Others
|
|
|
167 |
def clear_inputs_refs(*args):
|
168 |
return "", 5
|
169 |
|
170 |
+
|
171 |
def wrapped_generator(
|
172 |
paper_title, paper_description, # main input
|
173 |
+
openai_api_key=None, # key
|
174 |
+
tldr=True, max_kw_refs=10, refs=None, max_tokens_ref=2048, # references
|
175 |
knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
|
176 |
paper_template="ICLR2022", selected_sections=None, model="gpt-4", prompts_mode=False, # outputs parameters
|
177 |
cache_mode=IS_CACHE_AVAILABLE # handle cache mode
|
178 |
):
|
|
|
179 |
file_name_upload = urlify(paper_title) + "_" + uuid.uuid1().hex + ".zip"
|
180 |
+
|
181 |
+
# load the default configuration file
|
182 |
+
with open("configurations/default.yaml", 'r') as file:
|
183 |
+
config = yaml.safe_load(file)
|
184 |
+
config["paper"]["title"] = paper_title
|
185 |
+
config["paper"]["description"] = paper_description
|
186 |
+
config["references"]["tldr"] = tldr
|
187 |
+
config["references"]["max_kw_refs"] = max_kw_refs
|
188 |
+
config["references"]["refs"] = refs
|
189 |
+
config["references"]["max_tokens_ref"] = max_tokens_ref
|
190 |
+
config["domain_knowledge"]["knowledge_database"] = knowledge_database
|
191 |
+
config["domain_knowledge"]["max_tokens_kd"] = max_tokens_kd
|
192 |
+
config["domain_knowledge"]["query_counts"] = query_counts
|
193 |
+
config["output"]["selected_sections"] = selected_sections
|
194 |
+
config["output"]["model"] = model
|
195 |
+
config["output"]["template"] = paper_template
|
196 |
+
config["output"]["prompts_mode"] = prompts_mode
|
197 |
+
|
198 |
if openai_api_key is not None:
|
199 |
openai.api_key = openai_api_key
|
200 |
try:
|
|
|
202 |
except Exception as e:
|
203 |
raise gr.Error(f"Key错误. Error: {e}")
|
204 |
try:
|
205 |
+
output = generator_wrapper(config)
|
|
|
|
|
|
|
|
|
|
|
206 |
if cache_mode:
|
207 |
from utils.storage import upload_file
|
208 |
upload_file(output, target_name=file_name_upload)
|
|
|
218 |
with gr.Column(scale=2):
|
219 |
key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key",
|
220 |
visible=not IS_OPENAI_API_KEY_AVAILABLE)
|
|
|
|
|
221 |
# 每个功能做一个tab
|
222 |
with gr.Tab("学术论文"):
|
223 |
gr.Markdown(ACADEMIC_PAPER)
|
|
|
242 |
interactive=GPT4_INTERACTIVE,
|
243 |
info="生成论文用到的语言模型.")
|
244 |
prompts_mode = gr.Checkbox(value=False, visible=True, interactive=True,
|
245 |
+
label="Prompts模式",
|
246 |
+
info="只输出用于生成论文的Prompts, 可以复制到别的地方生成论文.")
|
247 |
|
248 |
sections = gr.CheckboxGroup(
|
249 |
choices=["introduction", "related works", "backgrounds", "methodology", "experiments",
|
|
|
257 |
|
258 |
with gr.Column(scale=2):
|
259 |
max_kw_ref_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
|
260 |
+
interactive=True, label="MAX_KW_REFS",
|
261 |
+
info="每个Keyword搜索几篇参考文献", visible=False)
|
262 |
|
263 |
max_tokens_ref_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
|
264 |
+
interactive=True, label="MAX_TOKENS",
|
265 |
+
info="参考文献内容占用Prompts中的Token数")
|
266 |
|
267 |
tldr_checkbox = gr.Checkbox(value=True, label="TLDR;",
|
268 |
info="选择此筐表示将使用Semantic Scholar的TLDR作为文献的总结.",
|
269 |
interactive=True)
|
270 |
+
|
271 |
+
text_ref = gr.Textbox(lines=5, label="References (Optional)", visible=True,
|
272 |
+
info="交给AI参考的文献的标题, 用英文逗号`,`隔开.")
|
273 |
+
|
274 |
+
gr.Examples(
|
275 |
+
examples = ["Understanding the Impact of Model Incoherence on Convergence of Incremental SGD with Random Reshuffle,"
|
276 |
+
"Variance-Reduced Off-Policy TDC Learning: Non-Asymptotic Convergence Analysis,"
|
277 |
+
"Greedy-GQ with Variance Reduction: Finite-time Analysis and Improved Complexity"],
|
278 |
+
inputs=text_ref,
|
279 |
+
cache_examples=False
|
280 |
+
)
|
281 |
|
282 |
with gr.Row():
|
283 |
with gr.Column(scale=1):
|
|
|
285 |
|
286 |
with gr.Column(scale=2):
|
287 |
query_counts_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
|
288 |
+
interactive=True, label="QUERY_COUNTS",
|
289 |
+
info="从知识库内检索多少条内容", visible=False)
|
290 |
max_tokens_kd_slider = gr.Slider(minimum=256, maximum=8192, value=2048, step=2,
|
291 |
+
interactive=True, label="MAX_TOKENS",
|
292 |
+
info="知识库内容占用Prompts中的Token数")
|
293 |
domain_knowledge = gr.Dropdown(label="预载知识库",
|
294 |
choices=ALL_DATABASES,
|
295 |
value="(None)",
|
|
|
314 |
json_output = gr.JSON(label="References")
|
315 |
clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
|
316 |
submit_button_pp.click(fn=wrapped_generator,
|
317 |
+
inputs=[title, description_pp, key,
|
318 |
+
tldr_checkbox, max_kw_ref_slider, text_ref, max_tokens_ref_slider,
|
319 |
domain_knowledge, max_tokens_kd_slider, query_counts_slider,
|
320 |
template, sections, model_selection, prompts_mode], outputs=file_output)
|
321 |
|
assets/idealab.png
DELETED
Binary file (52.1 kB)
|
|
auto_backgrounds.py → auto_generators.py
RENAMED
@@ -40,7 +40,7 @@ def log_usage(usage, generating_target, print_out=True):
|
|
40 |
|
41 |
|
42 |
def _generation_setup(title, description="", template="ICLR2022",
|
43 |
-
tldr=False, max_kw_refs=10,
|
44 |
knowledge_database=None, max_tokens_kd=2048, query_counts=10, # querying from knowledge database
|
45 |
debug=True):
|
46 |
"""
|
@@ -115,7 +115,7 @@ def _generation_setup(title, description="", template="ICLR2022",
|
|
115 |
|
116 |
print("Keywords: \n", keywords)
|
117 |
# todo: in some rare situations, collected papers will be an empty list. handle this issue
|
118 |
-
ref = References(title,
|
119 |
ref.collect_papers(keywords, tldr=tldr)
|
120 |
references = ref.to_prompts(max_tokens=max_tokens_ref)
|
121 |
all_paper_ids = ref.to_bibtex(bibtex_path)
|
@@ -200,7 +200,7 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
|
|
200 |
|
201 |
|
202 |
def generate_draft(title, description="", # main input
|
203 |
-
tldr=True, max_kw_refs=10,
|
204 |
knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
|
205 |
sections=None, model="gpt-4", template="ICLR2022", prompts_mode=False, # outputs parameters
|
206 |
):
|
@@ -245,7 +245,7 @@ def generate_draft(title, description="", # main input
|
|
245 |
"abstract"]
|
246 |
else:
|
247 |
sections = _filter_sections(sections)
|
248 |
-
paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs,
|
249 |
max_tokens_ref=max_tokens_ref, max_tokens_kd=max_tokens_kd,
|
250 |
query_counts=query_counts,
|
251 |
knowledge_database=knowledge_database)
|
@@ -254,11 +254,10 @@ def generate_draft(title, description="", # main input
|
|
254 |
prompts_dict = {}
|
255 |
print(f"================PROCESSING================")
|
256 |
for section in sections:
|
|
|
|
|
257 |
if prompts_mode:
|
258 |
-
prompts = generate_paper_prompts(paper, section)
|
259 |
-
prompts_dict[section] = prompts
|
260 |
continue
|
261 |
-
|
262 |
print(f"Generate {section} part...")
|
263 |
max_attempts = 4
|
264 |
attempts_count = 0
|
@@ -274,21 +273,16 @@ def generate_draft(title, description="", # main input
|
|
274 |
logging.info(message)
|
275 |
attempts_count += 1
|
276 |
time.sleep(15)
|
277 |
-
|
278 |
# post-processing
|
279 |
print("================POST-PROCESSING================")
|
280 |
create_copies(destination_folder)
|
281 |
-
|
282 |
-
|
|
|
283 |
print("\nMission completed.\n")
|
|
|
284 |
|
285 |
-
|
286 |
-
filename = hash_name(input_dict) + ".json"
|
287 |
-
with open(filename, "w") as f:
|
288 |
-
json.dump(prompts_dict, f)
|
289 |
-
return filename
|
290 |
-
else:
|
291 |
-
return make_archive(destination_folder, filename)
|
292 |
|
293 |
|
294 |
if __name__ == "__main__":
|
|
|
40 |
|
41 |
|
42 |
def _generation_setup(title, description="", template="ICLR2022",
|
43 |
+
tldr=False, max_kw_refs=10, refs=None, max_tokens_ref=2048, # generating references
|
44 |
knowledge_database=None, max_tokens_kd=2048, query_counts=10, # querying from knowledge database
|
45 |
debug=True):
|
46 |
"""
|
|
|
115 |
|
116 |
print("Keywords: \n", keywords)
|
117 |
# todo: in some rare situations, collected papers will be an empty list. handle this issue
|
118 |
+
ref = References(title, load_papers=refs)
|
119 |
ref.collect_papers(keywords, tldr=tldr)
|
120 |
references = ref.to_prompts(max_tokens=max_tokens_ref)
|
121 |
all_paper_ids = ref.to_bibtex(bibtex_path)
|
|
|
200 |
|
201 |
|
202 |
def generate_draft(title, description="", # main input
|
203 |
+
tldr=True, max_kw_refs=10, refs=None, max_tokens_ref=2048, # references
|
204 |
knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
|
205 |
sections=None, model="gpt-4", template="ICLR2022", prompts_mode=False, # outputs parameters
|
206 |
):
|
|
|
245 |
"abstract"]
|
246 |
else:
|
247 |
sections = _filter_sections(sections)
|
248 |
+
paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, refs,
|
249 |
max_tokens_ref=max_tokens_ref, max_tokens_kd=max_tokens_kd,
|
250 |
query_counts=query_counts,
|
251 |
knowledge_database=knowledge_database)
|
|
|
254 |
prompts_dict = {}
|
255 |
print(f"================PROCESSING================")
|
256 |
for section in sections:
|
257 |
+
prompts = generate_paper_prompts(paper, section)
|
258 |
+
prompts_dict[section] = prompts
|
259 |
if prompts_mode:
|
|
|
|
|
260 |
continue
|
|
|
261 |
print(f"Generate {section} part...")
|
262 |
max_attempts = 4
|
263 |
attempts_count = 0
|
|
|
273 |
logging.info(message)
|
274 |
attempts_count += 1
|
275 |
time.sleep(15)
|
|
|
276 |
# post-processing
|
277 |
print("================POST-PROCESSING================")
|
278 |
create_copies(destination_folder)
|
279 |
+
filename = "prompts.json"
|
280 |
+
with open(os.path.join(destination_folder, filename), "w") as f:
|
281 |
+
json.dump(prompts_dict, f)
|
282 |
print("\nMission completed.\n")
|
283 |
+
return destination_folder
|
284 |
|
285 |
+
# return make_archive(destination_folder, filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
|
288 |
if __name__ == "__main__":
|
configurations/default.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
date: 2023-07-11
|
2 |
+
|
3 |
+
generator: "auto_draft"
|
4 |
+
|
5 |
+
paper:
|
6 |
+
title: "playing atari game with deep reinforcement learning"
|
7 |
+
description: ""
|
8 |
+
|
9 |
+
references:
|
10 |
+
tldr: True
|
11 |
+
max_kw_refs: 10
|
12 |
+
max_tokens_ref: 2048
|
13 |
+
refs: null
|
14 |
+
|
15 |
+
domain_knowledge:
|
16 |
+
knowledge_database: null
|
17 |
+
max_tokens_kd: 2048
|
18 |
+
query_counts: 10
|
19 |
+
|
20 |
+
output:
|
21 |
+
template: "default"
|
22 |
+
model: "gpt-4"
|
23 |
+
selected_sections: null
|
24 |
+
prompts_mode: False
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
cyber-supervisor-openai.py
CHANGED
@@ -3,7 +3,7 @@ import openai
|
|
3 |
import ast
|
4 |
from tools import functions, TOOLS
|
5 |
|
6 |
-
MAX_ITER =
|
7 |
|
8 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
9 |
default_model = os.getenv("DEFAULT_MODEL")
|
|
|
3 |
import ast
|
4 |
from tools import functions, TOOLS
|
5 |
|
6 |
+
MAX_ITER = 99
|
7 |
|
8 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
9 |
default_model = os.getenv("DEFAULT_MODEL")
|
idealab.py
DELETED
@@ -1,144 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import os
|
3 |
-
import openai
|
4 |
-
from utils.references import References
|
5 |
-
from utils.gpt_interaction import GPTModel
|
6 |
-
from utils.prompts import SYSTEM
|
7 |
-
|
8 |
-
openai_key = os.getenv("OPENAI_API_KEY")
|
9 |
-
default_model = os.getenv("DEFAULT_MODEL")
|
10 |
-
if default_model is None:
|
11 |
-
# default_model = "gpt-3.5-turbo-16k"
|
12 |
-
default_model = "gpt-4"
|
13 |
-
|
14 |
-
openai.api_key = openai_key
|
15 |
-
|
16 |
-
paper_system_prompt = '''You are an assistant designed to propose choices of research direction.
|
17 |
-
The user will input questions or some keywords of a fields. You need to generate some paper titles and main contributions. Ensure follow the following instructions:
|
18 |
-
Instruction:
|
19 |
-
- Your response should follow the JSON format.
|
20 |
-
- Your response should have the following structure:
|
21 |
-
{
|
22 |
-
"your suggested paper title":
|
23 |
-
{
|
24 |
-
"summary": "an overview introducing what this paper will include",
|
25 |
-
"contributions": {
|
26 |
-
"contribution1": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
|
27 |
-
"contribution2": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
|
28 |
-
...
|
29 |
-
}
|
30 |
-
}
|
31 |
-
"your suggested paper title":
|
32 |
-
{
|
33 |
-
"summary": "an overview introducing what this paper will include",
|
34 |
-
"contributions": {
|
35 |
-
"contribution1": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
|
36 |
-
"contribution2": {"statement": "briefly describe this contribution", "reason": "reason why this contribution can make this paper outstanding"},
|
37 |
-
...
|
38 |
-
}
|
39 |
-
}
|
40 |
-
...
|
41 |
-
}
|
42 |
-
- Please list three to five suggested title and at least three contributions for each paper.
|
43 |
-
'''
|
44 |
-
|
45 |
-
|
46 |
-
contribution_system_prompt = '''You are an assistant designed to criticize the contributions of a paper. You will be provided Paper's Title, References and Contributions. Ensure follow the following instructions:
|
47 |
-
Instruction:
|
48 |
-
- Your response should follow the JSON format.
|
49 |
-
- Your response should have the following structure:
|
50 |
-
{
|
51 |
-
"title": "the title provided by the user",
|
52 |
-
"comment": "your thoughts on if this title clearly reflects the key ideas of this paper and explain why"
|
53 |
-
"contributions": {
|
54 |
-
"contribution1": {"statement": "briefly describe what the contribution is",
|
55 |
-
"reason": "reason why the user claims it is a contribution",
|
56 |
-
"judge": "your thought about if this is a novel contribution and explain why",
|
57 |
-
"suggestion": "your suggestion on how to modify the research direction to enhance the novelty "},
|
58 |
-
"contribution2": {"statement": "briefly describe what the contribution is",
|
59 |
-
"reason": "reason why the user claims it is a contribution",
|
60 |
-
"judge": "your thought about if this is a novel contribution and explain why",
|
61 |
-
"suggestion": "your suggestion on how to modify the research direction to enhance the novelty "},
|
62 |
-
...
|
63 |
-
}
|
64 |
-
}
|
65 |
-
- You need to carefully check if the claimed contribution has been made in the provided references, which makes the contribution not novel.
|
66 |
-
- You also need to propose your concerns on if any of contributions could be incremental or just a mild modification on an existing work.
|
67 |
-
'''
|
68 |
-
|
69 |
-
|
70 |
-
ANNOUNCEMENT = """
|
71 |
-
<h1 style="text-align: center"><img src='/file=assets/idealab.png' width=36px style="display: inline"/>灵感实验室IdeaLab</h1>
|
72 |
-
|
73 |
-
<p>灵感实验室IdeaLab可以为你选择你下一篇论文的研究方向! 输入你的研究领域或者任何想法, 灵感实验室会自动生成若干个论文标题+论文的主要贡献供你选择. </p>
|
74 |
-
|
75 |
-
<p>除此之外, 输入你的论文标题+主要贡献, 它会自动搜索相关文献, 来验证这个想法是不是有人做过了.</p>
|
76 |
-
"""
|
77 |
-
|
78 |
-
|
79 |
-
def criticize_my_idea(title, contributions, max_tokens=4096):
|
80 |
-
ref = References(title=title, description=f"{contributions}")
|
81 |
-
keywords, _ = llm(systems=SYSTEM["keywords"], prompts=title, return_json=True)
|
82 |
-
keywords = {keyword: 10 for keyword in keywords}
|
83 |
-
ref.collect_papers(keywords)
|
84 |
-
ref_prompt = ref.to_prompts(max_tokens=max_tokens)
|
85 |
-
|
86 |
-
prompt = f"Title: {title}\n References: {ref_prompt}\n Contributions: {contributions}"
|
87 |
-
output, _ = llm(systems=contribution_system_prompt, prompts=prompt, return_json=True)
|
88 |
-
return output, ref_prompt
|
89 |
-
|
90 |
-
def paste_title(suggestions):
|
91 |
-
if suggestions:
|
92 |
-
title = suggestions['title']['new title']
|
93 |
-
contributions = suggestions['contributions']
|
94 |
-
|
95 |
-
return title, contributions, {}, {}, {}
|
96 |
-
else:
|
97 |
-
return "", "", {}, {}, {}
|
98 |
-
|
99 |
-
def generate_choices(thoughts):
|
100 |
-
output, _ = llm(systems=paper_system_prompt, prompts=thoughts, return_json=True)
|
101 |
-
return output
|
102 |
-
|
103 |
-
|
104 |
-
# def translate_json(json_input):
|
105 |
-
# system_prompt = "You are a translation bot. The user will input a JSON format string. You need to translate it into Chinese and return in the same formmat."
|
106 |
-
# output, _ = llm(systems=system_prompt, prompts=str(json_input), return_json=True)
|
107 |
-
# return output
|
108 |
-
|
109 |
-
|
110 |
-
with gr.Blocks() as demo:
|
111 |
-
llm = GPTModel(model=default_model)
|
112 |
-
|
113 |
-
gr.HTML(ANNOUNCEMENT)
|
114 |
-
with gr.Row():
|
115 |
-
with gr.Tab("生成论文想法 (Generate Paper Ideas)"):
|
116 |
-
thoughts_input = gr.Textbox(label="Thoughts")
|
117 |
-
with gr.Accordion("Show prompts", open=False):
|
118 |
-
prompts_1 = gr.Textbox(label="Prompts", interactive=False, value=paper_system_prompt)
|
119 |
-
|
120 |
-
with gr.Row():
|
121 |
-
button_generate_idea = gr.Button("Make it an idea!", variant="primary")
|
122 |
-
|
123 |
-
with gr.Tab("验证想法可行性 (Validate Feasibility)"):
|
124 |
-
title_input = gr.Textbox(label="Title")
|
125 |
-
contribution_input = gr.Textbox(label="Contributions", lines=5)
|
126 |
-
with gr.Accordion("Show prompts", open=False):
|
127 |
-
prompts_2 = gr.Textbox(label="Prompts", interactive=False, value=contribution_system_prompt)
|
128 |
-
|
129 |
-
with gr.Row():
|
130 |
-
button_submit = gr.Button("Criticize my idea!", variant="primary")
|
131 |
-
|
132 |
-
with gr.Tab("生成论文 (Generate Paper)"):
|
133 |
-
gr.Markdown("...")
|
134 |
-
|
135 |
-
with gr.Column(scale=1):
|
136 |
-
contribution_output = gr.JSON(label="Contributions")
|
137 |
-
# cn_output = gr.JSON(label="主要贡献")
|
138 |
-
with gr.Accordion("References", open=False):
|
139 |
-
references_output = gr.JSON(label="References")
|
140 |
-
|
141 |
-
button_submit.click(fn=criticize_my_idea, inputs=[title_input, contribution_input], outputs=[contribution_output, references_output])
|
142 |
-
button_generate_idea.click(fn=generate_choices, inputs=thoughts_input, outputs=contribution_output)#.success(translate_json, contribution_output, cn_output)
|
143 |
-
demo.queue(concurrency_count=1, max_size=5, api_open=False)
|
144 |
-
demo.launch(show_error=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
kdb_test.py
CHANGED
@@ -6,11 +6,15 @@ import gradio as gr
|
|
6 |
import os
|
7 |
import json
|
8 |
from models import EMBEDDINGS
|
|
|
|
|
|
|
9 |
|
10 |
-
|
|
|
11 |
|
12 |
-
HF_TOKEN =
|
13 |
-
REPO_ID =
|
14 |
if HF_TOKEN is not None and REPO_ID is not None:
|
15 |
snapshot_download(REPO_ID, repo_type="dataset", local_dir="knowledge_databases/",
|
16 |
local_dir_use_symlinks=False, token=HF_TOKEN)
|
@@ -50,6 +54,29 @@ def query_from_kdb(input, kdb, query_counts):
|
|
50 |
raise RuntimeError(f"Failed to query from FAISS.")
|
51 |
return domain_knowledge, ""
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
with gr.Blocks() as demo:
|
54 |
with gr.Row():
|
55 |
with gr.Column():
|
@@ -76,9 +103,16 @@ with gr.Blocks() as demo:
|
|
76 |
interactive=True, label="QUERY_COUNTS",
|
77 |
info="How many contents will be retrieved from the vector database.")
|
78 |
|
79 |
-
|
|
|
|
|
80 |
|
81 |
-
button_retrieval.click(fn=query_from_kdb,
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
demo.queue(concurrency_count=1, max_size=5, api_open=False)
|
84 |
demo.launch(show_error=True)
|
|
|
6 |
import os
|
7 |
import json
|
8 |
from models import EMBEDDINGS
|
9 |
+
from utils.gpt_interaction import GPTModel
|
10 |
+
from utils.prompts import SYSTEM
|
11 |
+
import openai
|
12 |
|
13 |
+
llm = GPTModel(model="gpt-3.5-turbo")
|
14 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
15 |
|
16 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
17 |
+
REPO_ID = os.getenv("KDB_REPO")
|
18 |
if HF_TOKEN is not None and REPO_ID is not None:
|
19 |
snapshot_download(REPO_ID, repo_type="dataset", local_dir="knowledge_databases/",
|
20 |
local_dir_use_symlinks=False, token=HF_TOKEN)
|
|
|
54 |
raise RuntimeError(f"Failed to query from FAISS.")
|
55 |
return domain_knowledge, ""
|
56 |
|
57 |
+
def query_from_kdb_llm(title, contributions, kdb, query_counts):
|
58 |
+
if kdb == "(None)":
|
59 |
+
return {"knowledge_database": "(None)", "title": title, "contributions": contributions, "output": ""}, "", {}
|
60 |
+
|
61 |
+
db_path = f"knowledge_databases/{kdb}"
|
62 |
+
db_config_path = os.path.join(db_path, "db_meta.json")
|
63 |
+
db_index_path = os.path.join(db_path, "faiss_index")
|
64 |
+
if os.path.isdir(db_path):
|
65 |
+
# load configuration file
|
66 |
+
with open(db_config_path, "r", encoding="utf-8") as f:
|
67 |
+
db_config = json.load(f)
|
68 |
+
model_name = db_config["embedding_model"]
|
69 |
+
embeddings = EMBEDDINGS[model_name]
|
70 |
+
db = FAISS.load_local(db_index_path, embeddings)
|
71 |
+
knowledge = Knowledge(db=db)
|
72 |
+
prompts = f"Title: {title}\n Contributions: {contributions}"
|
73 |
+
preliminaries_kw, _ = llm(systems=SYSTEM["preliminaries"], prompts=prompts, return_json=True)
|
74 |
+
knowledge.collect_knowledge(preliminaries_kw, max_query=query_counts)
|
75 |
+
domain_knowledge = knowledge.to_json()
|
76 |
+
else:
|
77 |
+
raise RuntimeError(f"Failed to query from FAISS.")
|
78 |
+
return domain_knowledge, "", preliminaries_kw
|
79 |
+
|
80 |
with gr.Blocks() as demo:
|
81 |
with gr.Row():
|
82 |
with gr.Column():
|
|
|
103 |
interactive=True, label="QUERY_COUNTS",
|
104 |
info="How many contents will be retrieved from the vector database.")
|
105 |
|
106 |
+
with gr.Column():
|
107 |
+
retrieval_output = gr.JSON(label="Output")
|
108 |
+
llm_kws = gr.JSON(label="Keywords generated by LLM")
|
109 |
|
110 |
+
button_retrieval.click(fn=query_from_kdb,
|
111 |
+
inputs=[user_input, kdb_dropdown, query_counts_slider],
|
112 |
+
outputs=[retrieval_output, user_input])
|
113 |
+
button_retrieval_2.click(fn=query_from_kdb_llm,
|
114 |
+
inputs=[title_input, contribution_input, kdb_dropdown, query_counts_slider],
|
115 |
+
outputs=[retrieval_output, user_input, llm_kws])
|
116 |
|
117 |
demo.queue(concurrency_count=1, max_size=5, api_open=False)
|
118 |
demo.launch(show_error=True)
|
references_generator.py
DELETED
@@ -1,86 +0,0 @@
|
|
1 |
-
'''
|
2 |
-
This script is used to generate the most relevant papers of a given title.
|
3 |
-
- Search for as many as possible references. For 10~15 keywords, 10 references each.
|
4 |
-
- Sort the results from most relevant to least relevant.
|
5 |
-
- Return the most relevant using token size.
|
6 |
-
|
7 |
-
Note: we do not use this function in auto-draft function. It has been integrated in that.
|
8 |
-
'''
|
9 |
-
|
10 |
-
import os.path
|
11 |
-
import json
|
12 |
-
from utils.references import References
|
13 |
-
from section_generator import keywords_generation # section_generation_bg, #, figures_generation, section_generation
|
14 |
-
import itertools
|
15 |
-
from gradio_client import Client
|
16 |
-
|
17 |
-
|
18 |
-
def generate_raw_references(title, description="",
|
19 |
-
bib_refs=None, tldr=False, max_kw_refs=10,
|
20 |
-
save_to="ref.bib"):
|
21 |
-
# load pre-provided references
|
22 |
-
ref = References(title, bib_refs)
|
23 |
-
|
24 |
-
# generate multiple keywords for searching
|
25 |
-
input_dict = {"title": title, "description": description}
|
26 |
-
keywords, usage = keywords_generation(input_dict)
|
27 |
-
keywords = list(keywords)
|
28 |
-
comb_keywords = list(itertools.combinations(keywords, 2))
|
29 |
-
for comb_keyword in comb_keywords:
|
30 |
-
keywords.append(" ".join(comb_keyword))
|
31 |
-
keywords = {keyword:max_kw_refs for keyword in keywords}
|
32 |
-
print(f"keywords: {keywords}\n\n")
|
33 |
-
|
34 |
-
ref.collect_papers(keywords, tldr=tldr)
|
35 |
-
paper_json = ref.to_json()
|
36 |
-
|
37 |
-
with open(save_to, "w") as f:
|
38 |
-
json.dump(paper_json, f)
|
39 |
-
|
40 |
-
return save_to, ref # paper_json
|
41 |
-
|
42 |
-
def generate_top_k_references(title, description="",
|
43 |
-
bib_refs=None, tldr=False, max_kw_refs=10, save_to="ref.bib", top_k=5):
|
44 |
-
json_path, ref_raw = generate_raw_references(title, description, bib_refs, tldr, max_kw_refs, save_to)
|
45 |
-
json_content = ref_raw.to_json()
|
46 |
-
|
47 |
-
client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
|
48 |
-
result = client.predict(
|
49 |
-
title, # str in 'Title' Textbox component
|
50 |
-
json_path, # str (filepath or URL to file) in 'Papers JSON (as string)' File component
|
51 |
-
top_k, # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
|
52 |
-
api_name="/get_k_relevant_papers"
|
53 |
-
)
|
54 |
-
with open(result) as f:
|
55 |
-
result = json.load(f)
|
56 |
-
return result
|
57 |
-
|
58 |
-
|
59 |
-
if __name__ == "__main__":
|
60 |
-
import openai
|
61 |
-
openai.api_key = os.getenv("OPENAI_API_KEY")
|
62 |
-
|
63 |
-
title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
|
64 |
-
description = ""
|
65 |
-
save_to = "paper.json"
|
66 |
-
save_to, paper_json = generate_raw_references(title, description, save_to=save_to)
|
67 |
-
|
68 |
-
print("`paper.json` has been generated. Now evaluating its similarity...")
|
69 |
-
|
70 |
-
k = 5
|
71 |
-
client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
|
72 |
-
result = client.predict(
|
73 |
-
title, # str in 'Title' Textbox component
|
74 |
-
save_to, # str (filepath or URL to file) in 'Papers JSON (as string)' File component
|
75 |
-
k, # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
|
76 |
-
api_name="/get_k_relevant_papers"
|
77 |
-
)
|
78 |
-
|
79 |
-
with open(result) as f:
|
80 |
-
result = json.load(f)
|
81 |
-
|
82 |
-
print(result)
|
83 |
-
|
84 |
-
save_to = "paper2.json"
|
85 |
-
with open(save_to, "w") as f:
|
86 |
-
json.dump(result, f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/knowledge.py
CHANGED
@@ -16,7 +16,7 @@ class Knowledge:
|
|
16 |
self.db = db
|
17 |
self.contents = []
|
18 |
|
19 |
-
def collect_knowledge(self, keywords_dict, max_query):
|
20 |
"""
|
21 |
keywords_dict:
|
22 |
{"machine learning": 5, "language model": 2};
|
|
|
16 |
self.db = db
|
17 |
self.contents = []
|
18 |
|
19 |
+
def collect_knowledge(self, keywords_dict: dict, max_query: int):
|
20 |
"""
|
21 |
keywords_dict:
|
22 |
{"machine learning": 5, "language model": 2};
|
utils/references.py
CHANGED
@@ -3,52 +3,68 @@
|
|
3 |
#
|
4 |
# Generate references:
|
5 |
# `Reference` class:
|
6 |
-
# 1.
|
|
|
|
|
7 |
# 2. Given some keywords; use Semantic Scholar API to find papers.
|
8 |
# 3. Generate bibtex from the selected papers. --> to_bibtex()
|
9 |
# 4. Generate prompts from the selected papers: --> to_prompts()
|
10 |
# A sample prompt: {"paper_id": "paper summary"}
|
|
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
# add all citations to `bib_papers`
|
15 |
-
# add all citedby to `bib_papers`
|
16 |
-
# use Semantic Scholar to find their embeddings
|
17 |
-
# (2) separate references:
|
18 |
-
# divide references into different groups to reduce the tokens count
|
19 |
-
# for generating different paragraph of related works, use different set of references
|
20 |
-
from typing import Dict, List
|
21 |
-
import requests
|
22 |
import re
|
|
|
|
|
|
|
|
|
23 |
import bibtexparser
|
24 |
-
import random
|
25 |
-
from scholarly import scholarly
|
26 |
-
from scholarly import ProxyGenerator
|
27 |
-
import tiktoken
|
28 |
-
import itertools, uuid, json
|
29 |
-
from gradio_client import Client
|
30 |
-
import time
|
31 |
import numpy as np
|
|
|
|
|
32 |
from numpy.linalg import norm
|
|
|
|
|
33 |
|
34 |
-
|
35 |
URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"
|
36 |
MAX_BATCH_SIZE = 16
|
37 |
MAX_ATTEMPTS = 20
|
38 |
|
|
|
|
|
|
|
|
|
|
|
39 |
######################################################################################################################
|
40 |
# Some basic tools
|
41 |
######################################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def evaluate_cosine_similarity(v1, v2):
|
43 |
try:
|
44 |
-
return np.dot(v1, v2)/(norm(v1)*norm(v2))
|
45 |
except ValueError:
|
46 |
return 0.0
|
47 |
|
|
|
48 |
def chunks(lst, chunk_size=MAX_BATCH_SIZE):
|
49 |
"""Splits a longer list to respect batch size"""
|
50 |
for i in range(0, len(lst), chunk_size):
|
51 |
-
yield lst[i
|
|
|
52 |
|
53 |
def embed(papers):
|
54 |
embeddings_by_paper_id: Dict[str, List[float]] = {}
|
@@ -64,6 +80,7 @@ def embed(papers):
|
|
64 |
|
65 |
return embeddings_by_paper_id
|
66 |
|
|
|
67 |
def get_embeddings(paper_title, paper_description):
|
68 |
output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
|
69 |
emb_vector = embed(output)["target_paper"]
|
@@ -71,9 +88,17 @@ def get_embeddings(paper_title, paper_description):
|
|
71 |
target_paper["embeddings"] = emb_vector
|
72 |
return target_paper
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
def get_top_k(papers_dict, paper_title, paper_description, k=None):
|
|
|
75 |
target_paper = get_embeddings(paper_title, paper_description)
|
76 |
-
papers = papers_dict
|
77 |
|
78 |
# if k < len(papers_json), return k most relevant papers
|
79 |
# if k >= len(papers_json) or k is None, return all papers
|
@@ -88,7 +113,7 @@ def get_top_k(papers_dict, paper_title, paper_description, k=None):
|
|
88 |
for k in papers:
|
89 |
v = papers[k]
|
90 |
embedding_vector = v["embeddings"]
|
91 |
-
cos_sim
|
92 |
papers[k]["cos_sim"] = cos_sim
|
93 |
|
94 |
# return the best k papers
|
@@ -97,14 +122,6 @@ def get_top_k(papers_dict, paper_title, paper_description, k=None):
|
|
97 |
sorted_papers[key].pop("embeddings", None)
|
98 |
return sorted_papers
|
99 |
|
100 |
-
def remove_newlines(serie):
|
101 |
-
# This function is applied to the abstract of each paper to reduce the length of prompts.
|
102 |
-
serie = serie.replace('\n', ' ')
|
103 |
-
serie = serie.replace('\\n', ' ')
|
104 |
-
serie = serie.replace(' ', ' ')
|
105 |
-
serie = serie.replace(' ', ' ')
|
106 |
-
return serie
|
107 |
-
|
108 |
|
109 |
def search_paper_abstract(title):
|
110 |
pg = ProxyGenerator()
|
@@ -123,6 +140,159 @@ def search_paper_abstract(title):
|
|
123 |
return remove_newlines(found_paper['bib']['abstract'])
|
124 |
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
def load_papers_from_bibtex(bib_file_path):
|
127 |
with open(bib_file_path) as bibtex_file:
|
128 |
bib_database = bibtexparser.load(bibtex_file)
|
@@ -154,15 +324,20 @@ def load_papers_from_bibtex(bib_file_path):
|
|
154 |
bib_papers.append(result)
|
155 |
return bib_papers
|
156 |
|
157 |
-
# `tokenizer`: used to count how many tokens
|
158 |
-
tokenizer_name = tiktoken.encoding_for_model('gpt-4')
|
159 |
-
tokenizer = tiktoken.get_encoding(tokenizer_name.name)
|
160 |
-
|
161 |
|
162 |
-
def
|
163 |
-
#
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
|
168 |
######################################################################################################################
|
@@ -174,7 +349,7 @@ def ss_search(keywords, limit=20, fields=None):
|
|
174 |
fields = ["title", "abstract", "venue", "year", "authors", "tldr", "embedding", "externalIds"]
|
175 |
keywords = keywords.lower()
|
176 |
keywords = keywords.replace(" ", "+")
|
177 |
-
url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={keywords}&limit={limit}&fields={",".join(fields)}'
|
178 |
# headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
|
179 |
headers = {"Accept": "*/*"}
|
180 |
|
@@ -183,27 +358,6 @@ def ss_search(keywords, limit=20, fields=None):
|
|
183 |
|
184 |
|
185 |
def _collect_papers_ss(keyword, counts=3, tldr=False):
|
186 |
-
def externalIds2link(externalIds):
|
187 |
-
# Sample externalIds:
|
188 |
-
# "{'MAG': '2932819148', 'DBLP': 'conf/icml/HaarnojaZAL18', 'ArXiv': '1801.01290', 'CorpusId': 28202810}"
|
189 |
-
if externalIds:
|
190 |
-
# Supports ArXiv, MAG, ACL, PubMed, Medline, PubMedCentral, DBLP, DOI
|
191 |
-
# priority: DBLP > arXiv > (todo: MAG > CorpusId > DOI > ACL > PubMed > Mdeline > PubMedCentral)
|
192 |
-
# DBLP
|
193 |
-
dblp_id = externalIds.get('DBLP')
|
194 |
-
if dblp_id is not None:
|
195 |
-
dblp_link = f"dblp.org/rec/{dblp_id}"
|
196 |
-
return dblp_link
|
197 |
-
# arXiv
|
198 |
-
arxiv_id = externalIds.get('ArXiv')
|
199 |
-
if arxiv_id is not None:
|
200 |
-
arxiv_link = f"arxiv.org/abs/{arxiv_id}"
|
201 |
-
return arxiv_link
|
202 |
-
return ""
|
203 |
-
else:
|
204 |
-
# if this is an empty dictionary, return an empty string
|
205 |
-
return ""
|
206 |
-
|
207 |
def extract_paper_id(last_name, year_str, title):
|
208 |
pattern = r'^\w+'
|
209 |
words = re.findall(pattern, title)
|
@@ -289,24 +443,28 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
|
|
289 |
######################################################################################################################
|
290 |
|
291 |
class References:
|
292 |
-
def __init__(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
if load_papers is not None:
|
294 |
-
self.papers =
|
295 |
-
|
296 |
-
self.papers = {}
|
297 |
self.title = title
|
298 |
self.description = description
|
299 |
|
300 |
-
def
|
301 |
-
self.papers[keyword] = load_papers_from_bibtex(bibtex)
|
302 |
-
|
303 |
-
def generate_keywords_dict(self):
|
304 |
keywords_dict = {}
|
305 |
for k in self.papers:
|
306 |
keywords_dict[k] = len(self.papers[k])
|
307 |
return keywords_dict
|
308 |
|
309 |
-
def collect_papers(self, keywords_dict, tldr=False):
|
310 |
"""
|
311 |
Collect as many papers as possible
|
312 |
|
@@ -320,21 +478,15 @@ class References:
|
|
320 |
keywords.append(" ".join(comb_keyword))
|
321 |
for key in keywords:
|
322 |
self.papers[key] = _collect_papers_ss(key, 10, tldr)
|
323 |
-
# print("Collected papers: ", papers)
|
324 |
-
# for key, counts in keywords_dict.items():
|
325 |
-
# self.papers[key] = _collect_papers_ss(key, counts, tldr)
|
326 |
|
327 |
-
def to_bibtex(self, path_to_bibtex="ref.bib"):
|
328 |
"""
|
329 |
Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
|
330 |
"""
|
331 |
-
# todo:
|
332 |
-
# use embeddings to evaluate; keep top k relevant references in papers
|
333 |
-
# send (title, .bib file) to evaluate embeddings; recieve truncated papers
|
334 |
papers = self._get_papers(keyword="_all")
|
335 |
|
336 |
-
|
337 |
-
print(f"{
|
338 |
# clear the bibtex file
|
339 |
with open(path_to_bibtex, "w", encoding="utf-8") as file:
|
340 |
file.write("")
|
@@ -372,7 +524,7 @@ class References:
|
|
372 |
papers = self.papers["keyword"]
|
373 |
return papers
|
374 |
|
375 |
-
def to_prompts(self, keyword="_all", max_tokens=2048):
|
376 |
# `prompts`:
|
377 |
# {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
|
378 |
# this will be used to instruct GPT model to cite the correct bibtex entry.
|
@@ -384,21 +536,11 @@ class References:
|
|
384 |
papers_json = self.to_json()
|
385 |
with open(json_path, "w") as f:
|
386 |
json.dump(papers_json, f)
|
387 |
-
|
388 |
try:
|
389 |
# Use external API to obtain the most relevant papers
|
390 |
title = self.title
|
391 |
description = self.description
|
392 |
result = get_top_k(papers_json, title, description)
|
393 |
-
# client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
|
394 |
-
# result = client.predict(
|
395 |
-
# title, # str in 'Title' Textbox component
|
396 |
-
# json_path, # str (filepath or URL to file) in 'Papers JSON (as string)' File component
|
397 |
-
# 50, # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
|
398 |
-
# api_name="/get_k_relevant_papers"
|
399 |
-
# )
|
400 |
-
# with open(result) as f:
|
401 |
-
# result = json.load(f)
|
402 |
result = [item for key, item in result.items()]
|
403 |
except Exception as e:
|
404 |
print(f"Error occurs during calling external API: {e}\n")
|
@@ -417,54 +559,9 @@ class References:
|
|
417 |
break
|
418 |
return prompts
|
419 |
|
420 |
-
def to_json(self, keyword="_all"):
|
421 |
papers = self._get_papers(keyword)
|
422 |
papers_json = {}
|
423 |
for paper in papers:
|
424 |
papers_json[paper["paper_id"]] = paper
|
425 |
return papers_json
|
426 |
-
|
427 |
-
|
428 |
-
if __name__ == "__main__":
|
429 |
-
# testing search results
|
430 |
-
print("================Testing `ss_search`================")
|
431 |
-
r = ss_search("Deep Q-Networks", limit=1) # a list of raw papers
|
432 |
-
if r['total'] > 0:
|
433 |
-
paper = r['data'][0]
|
434 |
-
# print(paper)
|
435 |
-
|
436 |
-
# resting References
|
437 |
-
print("================Testing `References`================")
|
438 |
-
refs = References(title="Super Deep Q-Networks")
|
439 |
-
keywords_dict = {
|
440 |
-
"Deep Q-Networks": 5,
|
441 |
-
"Actor-Critic Algorithms": 4,
|
442 |
-
"Exploration-Exploitation Trade-off": 3
|
443 |
-
}
|
444 |
-
print("================Testing `References.collect_papers`================")
|
445 |
-
refs.collect_papers(keywords_dict, tldr=True)
|
446 |
-
for k in refs.papers:
|
447 |
-
papers = refs.papers[k] # for each keyword, there is a list of papers
|
448 |
-
print("keyword: ", k)
|
449 |
-
for paper in papers:
|
450 |
-
print(paper["paper_id"])
|
451 |
-
|
452 |
-
print("================Testing `References.to_bibtex`================")
|
453 |
-
refs.to_bibtex()
|
454 |
-
|
455 |
-
print("================Testing `References.to_json`================")
|
456 |
-
papers_json = refs.to_json() # this json can be used to find the most relevant papers
|
457 |
-
with open("papers.json", "w", encoding='utf-8') as text_file:
|
458 |
-
text_file.write(f"{papers_json}")
|
459 |
-
|
460 |
-
print("================Testing `References.to_prompts`================")
|
461 |
-
prompts = refs.to_prompts()
|
462 |
-
print(prompts)
|
463 |
-
|
464 |
-
# bib = "test.bib"
|
465 |
-
# refs.load_papers(bib, "variance-reduction rl")
|
466 |
-
# print(refs.papers)
|
467 |
-
#
|
468 |
-
# prompts = refs.to_prompts()
|
469 |
-
# for k in prompts:
|
470 |
-
# print(f"{k}: {prompts[k]}\n")
|
|
|
3 |
#
|
4 |
# Generate references:
|
5 |
# `Reference` class:
|
6 |
+
# 1. Two methods to load papers:
|
7 |
+
# 1.1. Read a given string including paper titles separated by `,`
|
8 |
+
# 1.2. Read a .bib file
|
9 |
# 2. Given some keywords; use Semantic Scholar API to find papers.
|
10 |
# 3. Generate bibtex from the selected papers. --> to_bibtex()
|
11 |
# 4. Generate prompts from the selected papers: --> to_prompts()
|
12 |
# A sample prompt: {"paper_id": "paper summary"}
|
13 |
+
# 5. Generate json from the selected papers. --> to_json()
|
14 |
|
15 |
+
import itertools
|
16 |
+
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
import re
|
18 |
+
import uuid
|
19 |
+
from typing import Dict, List, Optional, Union
|
20 |
+
|
21 |
+
import arxiv
|
22 |
import bibtexparser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
import numpy as np
|
24 |
+
import requests
|
25 |
+
import tiktoken
|
26 |
from numpy.linalg import norm
|
27 |
+
from scholarly import ProxyGenerator
|
28 |
+
from scholarly import scholarly
|
29 |
|
30 |
+
# used to evaluate embeddings
|
31 |
URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"
|
32 |
MAX_BATCH_SIZE = 16
|
33 |
MAX_ATTEMPTS = 20
|
34 |
|
35 |
+
# `tokenizer`: used to count how many tokens
|
36 |
+
tokenizer_name = tiktoken.encoding_for_model('gpt-4')
|
37 |
+
tokenizer = tiktoken.get_encoding(tokenizer_name.name)
|
38 |
+
|
39 |
+
|
40 |
######################################################################################################################
|
41 |
# Some basic tools
|
42 |
######################################################################################################################
|
43 |
+
def remove_special_characters(s):
|
44 |
+
return ''.join(c for c in s if c.isalnum() or c.isspace() or c == ',')
|
45 |
+
|
46 |
+
|
47 |
+
def remove_newlines(serie):
|
48 |
+
# This function is applied to the abstract of each paper to reduce the length of prompts.
|
49 |
+
serie = serie.replace('\n', ' ')
|
50 |
+
serie = serie.replace('\\n', ' ')
|
51 |
+
serie = serie.replace(' ', ' ')
|
52 |
+
serie = serie.replace(' ', ' ')
|
53 |
+
return serie
|
54 |
+
|
55 |
+
|
56 |
def evaluate_cosine_similarity(v1, v2):
|
57 |
try:
|
58 |
+
return np.dot(v1, v2) / (norm(v1) * norm(v2))
|
59 |
except ValueError:
|
60 |
return 0.0
|
61 |
|
62 |
+
|
63 |
def chunks(lst, chunk_size=MAX_BATCH_SIZE):
|
64 |
"""Splits a longer list to respect batch size"""
|
65 |
for i in range(0, len(lst), chunk_size):
|
66 |
+
yield lst[i: i + chunk_size]
|
67 |
+
|
68 |
|
69 |
def embed(papers):
|
70 |
embeddings_by_paper_id: Dict[str, List[float]] = {}
|
|
|
80 |
|
81 |
return embeddings_by_paper_id
|
82 |
|
83 |
+
|
84 |
def get_embeddings(paper_title, paper_description):
|
85 |
output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
|
86 |
emb_vector = embed(output)["target_paper"]
|
|
|
88 |
target_paper["embeddings"] = emb_vector
|
89 |
return target_paper
|
90 |
|
91 |
+
|
92 |
+
def get_embeddings_vector(paper_title, paper_description):
|
93 |
+
output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
|
94 |
+
emb_vector = embed(output)["target_paper"]
|
95 |
+
return emb_vector
|
96 |
+
|
97 |
+
|
98 |
def get_top_k(papers_dict, paper_title, paper_description, k=None):
|
99 |
+
# returns the top k papers most similar to the target paper
|
100 |
target_paper = get_embeddings(paper_title, paper_description)
|
101 |
+
papers = papers_dict # must include embeddings
|
102 |
|
103 |
# if k < len(papers_json), return k most relevant papers
|
104 |
# if k >= len(papers_json) or k is None, return all papers
|
|
|
113 |
for k in papers:
|
114 |
v = papers[k]
|
115 |
embedding_vector = v["embeddings"]
|
116 |
+
cos_sim = evaluate_cosine_similarity(embedding_vector, target_embedding_vector)
|
117 |
papers[k]["cos_sim"] = cos_sim
|
118 |
|
119 |
# return the best k papers
|
|
|
122 |
sorted_papers[key].pop("embeddings", None)
|
123 |
return sorted_papers
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
def search_paper_abstract(title):
|
127 |
pg = ProxyGenerator()
|
|
|
140 |
return remove_newlines(found_paper['bib']['abstract'])
|
141 |
|
142 |
|
143 |
+
def tiktoken_len(text):
|
144 |
+
# evaluate how many tokens for the given text
|
145 |
+
tokens = tokenizer.encode(text, disallowed_special=())
|
146 |
+
return len(tokens)
|
147 |
+
|
148 |
+
|
149 |
+
######################################################################################################################
|
150 |
+
# Academic search tools
|
151 |
+
######################################################################################################################
|
152 |
+
def externalIds2link(externalIds):
|
153 |
+
# Sample externalIds:
|
154 |
+
# "{'MAG': '2932819148', 'DBLP': 'conf/icml/HaarnojaZAL18', 'ArXiv': '1801.01290', 'CorpusId': 28202810}"
|
155 |
+
if externalIds:
|
156 |
+
# Supports ArXiv, MAG, ACL, PubMed, Medline, PubMedCentral, DBLP, DOI
|
157 |
+
# priority: DBLP > arXiv > (todo: MAG > CorpusId > DOI > ACL > PubMed > Mdeline > PubMedCentral)
|
158 |
+
# DBLP
|
159 |
+
dblp_id = externalIds.get('DBLP')
|
160 |
+
if dblp_id is not None:
|
161 |
+
dblp_link = f"dblp.org/rec/{dblp_id}"
|
162 |
+
return dblp_link
|
163 |
+
# arXiv
|
164 |
+
arxiv_id = externalIds.get('ArXiv')
|
165 |
+
if arxiv_id is not None:
|
166 |
+
arxiv_link = f"arxiv.org/abs/{arxiv_id}"
|
167 |
+
return arxiv_link
|
168 |
+
return ""
|
169 |
+
else:
|
170 |
+
# if this is an empty dictionary, return an empty string
|
171 |
+
return ""
|
172 |
+
|
173 |
+
|
174 |
+
def search_paper_arxiv(title):
|
175 |
+
search = arxiv.Search(
|
176 |
+
query=title,
|
177 |
+
max_results=1,
|
178 |
+
sort_by=arxiv.SortCriterion.Relevance
|
179 |
+
)
|
180 |
+
try:
|
181 |
+
# (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal (8) embeddings
|
182 |
+
result = next(search.results())
|
183 |
+
title = result.title
|
184 |
+
authors = " and ".join([author.name for author in result.authors])
|
185 |
+
year = str(result.updated.now().year)
|
186 |
+
link = result.pdf_url
|
187 |
+
abstract = result.summary
|
188 |
+
journal = f"Arxiv: {result.entry_id}"
|
189 |
+
paper_id = result.authors[0].name.replace(" ", "")[:4] + year + title[:6].replace(" ", "")
|
190 |
+
paper_id = paper_id.lower()
|
191 |
+
|
192 |
+
paper = {"paper_id": paper_id,
|
193 |
+
"title": title,
|
194 |
+
"authors": authors,
|
195 |
+
"year": year,
|
196 |
+
"link": link,
|
197 |
+
"abstract": abstract,
|
198 |
+
"journal": journal}
|
199 |
+
except StopIteration:
|
200 |
+
paper = {}
|
201 |
+
return paper
|
202 |
+
|
203 |
+
|
204 |
+
def search_paper_ss(title):
|
205 |
+
fields = ["title", "abstract", "venue", "year", "authors", "tldr", "externalIds"]
|
206 |
+
limit = 1
|
207 |
+
url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={title}&limit={limit}&fields={",".join(fields)}'
|
208 |
+
# headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
|
209 |
+
headers = {"Accept": "*/*"}
|
210 |
+
response = requests.get(url, headers=headers, timeout=30)
|
211 |
+
results = response.json()
|
212 |
+
if results['total'] == 0:
|
213 |
+
return {}
|
214 |
+
raw_paper = results['data'][0]
|
215 |
+
if raw_paper['tldr'] is not None:
|
216 |
+
abstract = raw_paper['tldr']['text']
|
217 |
+
elif raw_paper['abstract'] is not None:
|
218 |
+
abstract = remove_newlines(raw_paper['abstract'])
|
219 |
+
else:
|
220 |
+
abstract = ""
|
221 |
+
|
222 |
+
authors = [author['name'] for author in raw_paper['authors']]
|
223 |
+
authors_str = " and ".join(authors)
|
224 |
+
year_str = str(raw_paper['year'])
|
225 |
+
title = raw_paper['title']
|
226 |
+
|
227 |
+
paper_id = authors_str.replace(" ", "")[:4] + year_str + title[:6].replace(" ", "")
|
228 |
+
|
229 |
+
# some journal may contain &; replace it. e.g. journal={IEEE Power & Energy Society General Meeting}
|
230 |
+
journal = remove_special_characters(raw_paper['venue'])
|
231 |
+
if not journal:
|
232 |
+
journal = "arXiv preprint"
|
233 |
+
link = externalIds2link(raw_paper['externalIds'])
|
234 |
+
paper = {
|
235 |
+
"paper_id": paper_id,
|
236 |
+
"title": title,
|
237 |
+
"abstract": abstract,
|
238 |
+
"link": link,
|
239 |
+
"authors": authors_str,
|
240 |
+
"year": year_str,
|
241 |
+
"journal": journal
|
242 |
+
}
|
243 |
+
return paper
|
244 |
+
|
245 |
+
|
246 |
+
def search_paper_scrape(title):
|
247 |
+
pg = ProxyGenerator()
|
248 |
+
success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
|
249 |
+
if success:
|
250 |
+
try:
|
251 |
+
scholarly.use_proxy(pg)
|
252 |
+
# input the title of a paper, return its abstract
|
253 |
+
search_query = scholarly.search_pubs(title)
|
254 |
+
found_paper = next(search_query)
|
255 |
+
url = found_paper['pub_url']
|
256 |
+
|
257 |
+
result = found_paper['bib']
|
258 |
+
|
259 |
+
title = result['title']
|
260 |
+
authors = " and ".join(result['author'])
|
261 |
+
year = str(result['pub_year'])
|
262 |
+
journal = result['pub_year']
|
263 |
+
abstract = result['abstract']
|
264 |
+
|
265 |
+
paper_id = authors.replace(" ", "")[:4] + year + title[:6].replace(" ", "")
|
266 |
+
paper = {
|
267 |
+
"paper_id": paper_id,
|
268 |
+
"title": title,
|
269 |
+
"abstract": abstract,
|
270 |
+
"link": url,
|
271 |
+
"authors": authors,
|
272 |
+
"year": year,
|
273 |
+
"journal": journal
|
274 |
+
}
|
275 |
+
return paper
|
276 |
+
except StopIteration:
|
277 |
+
return {}
|
278 |
+
|
279 |
+
|
280 |
+
def search_paper(title, verbose=True):
|
281 |
+
if verbose:
|
282 |
+
print(f"Searching {title}...")
|
283 |
+
# try Semantic Scholar first
|
284 |
+
paper = search_paper_ss(title)
|
285 |
+
if not paper:
|
286 |
+
paper = search_paper_arxiv(title)
|
287 |
+
if not paper:
|
288 |
+
paper = search_paper_scrape(title)
|
289 |
+
if paper:
|
290 |
+
paper["embeddings"] = get_embeddings_vector(paper_title=paper['title'], paper_description=paper['abstract'])
|
291 |
+
if verbose:
|
292 |
+
print(f"Search result: {paper}.")
|
293 |
+
return paper
|
294 |
+
|
295 |
+
|
296 |
def load_papers_from_bibtex(bib_file_path):
|
297 |
with open(bib_file_path) as bibtex_file:
|
298 |
bib_database = bibtexparser.load(bibtex_file)
|
|
|
324 |
bib_papers.append(result)
|
325 |
return bib_papers
|
326 |
|
|
|
|
|
|
|
|
|
327 |
|
328 |
+
def load_papers_from_text(text):
|
329 |
+
# split text by comma
|
330 |
+
titles = [part.strip() for part in text.split(',')]
|
331 |
+
titles = [remove_special_characters(title) for title in titles]
|
332 |
+
papers = []
|
333 |
+
if len(titles) > 0:
|
334 |
+
for title in titles:
|
335 |
+
paper = search_paper(title)
|
336 |
+
if paper:
|
337 |
+
papers.append(paper)
|
338 |
+
return papers
|
339 |
+
else:
|
340 |
+
return []
|
341 |
|
342 |
|
343 |
######################################################################################################################
|
|
|
349 |
fields = ["title", "abstract", "venue", "year", "authors", "tldr", "embedding", "externalIds"]
|
350 |
keywords = keywords.lower()
|
351 |
keywords = keywords.replace(" ", "+")
|
352 |
+
url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={keywords}&limit={limit}&fields={",".join(fields)} '
|
353 |
# headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
|
354 |
headers = {"Accept": "*/*"}
|
355 |
|
|
|
358 |
|
359 |
|
360 |
def _collect_papers_ss(keyword, counts=3, tldr=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
def extract_paper_id(last_name, year_str, title):
|
362 |
pattern = r'^\w+'
|
363 |
words = re.findall(pattern, title)
|
|
|
443 |
######################################################################################################################
|
444 |
|
445 |
class References:
|
446 |
+
def __init__(self,
|
447 |
+
title: str,
|
448 |
+
load_papers: Optional[str] = None,
|
449 |
+
load_bibtex: Optional[str] = None,
|
450 |
+
description: str = ""
|
451 |
+
):
|
452 |
+
self.papers = {}
|
453 |
+
if load_bibtex is not None:
|
454 |
+
self.papers["load_from_bibtex"] = load_papers_from_bibtex(load_bibtex)
|
455 |
if load_papers is not None:
|
456 |
+
self.papers["load_from_text"] = load_papers_from_text(load_papers)
|
457 |
+
|
|
|
458 |
self.title = title
|
459 |
self.description = description
|
460 |
|
461 |
+
def generate_keywords_dict(self) -> Dict[str, int]:
|
|
|
|
|
|
|
462 |
keywords_dict = {}
|
463 |
for k in self.papers:
|
464 |
keywords_dict[k] = len(self.papers[k])
|
465 |
return keywords_dict
|
466 |
|
467 |
+
def collect_papers(self, keywords_dict: Dict[str, int], tldr: bool = False) -> None:
|
468 |
"""
|
469 |
Collect as many papers as possible
|
470 |
|
|
|
478 |
keywords.append(" ".join(comb_keyword))
|
479 |
for key in keywords:
|
480 |
self.papers[key] = _collect_papers_ss(key, 10, tldr)
|
|
|
|
|
|
|
481 |
|
482 |
+
def to_bibtex(self, path_to_bibtex: str = "ref.bib") -> List[str]:
|
483 |
"""
|
484 |
Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`.
|
485 |
"""
|
|
|
|
|
|
|
486 |
papers = self._get_papers(keyword="_all")
|
487 |
|
488 |
+
num_papers = len(papers)
|
489 |
+
print(f"{num_papers} papers will be added to `ref.bib`.")
|
490 |
# clear the bibtex file
|
491 |
with open(path_to_bibtex, "w", encoding="utf-8") as file:
|
492 |
file.write("")
|
|
|
524 |
papers = self.papers["keyword"]
|
525 |
return papers
|
526 |
|
527 |
+
def to_prompts(self, keyword: str = "_all", max_tokens: int = 2048):
|
528 |
# `prompts`:
|
529 |
# {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"}
|
530 |
# this will be used to instruct GPT model to cite the correct bibtex entry.
|
|
|
536 |
papers_json = self.to_json()
|
537 |
with open(json_path, "w") as f:
|
538 |
json.dump(papers_json, f)
|
|
|
539 |
try:
|
540 |
# Use external API to obtain the most relevant papers
|
541 |
title = self.title
|
542 |
description = self.description
|
543 |
result = get_top_k(papers_json, title, description)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
544 |
result = [item for key, item in result.items()]
|
545 |
except Exception as e:
|
546 |
print(f"Error occurs during calling external API: {e}\n")
|
|
|
559 |
break
|
560 |
return prompts
|
561 |
|
562 |
+
def to_json(self, keyword: str = "_all"):
|
563 |
papers = self._get_papers(keyword)
|
564 |
papers_json = {}
|
565 |
for paper in papers:
|
566 |
papers_json[paper["paper_id"]] = paper
|
567 |
return papers_json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
worker.py
CHANGED
@@ -3,7 +3,7 @@ This script is only used for service-side host.
|
|
3 |
'''
|
4 |
import boto3
|
5 |
import os, time
|
6 |
-
from
|
7 |
from sqlalchemy import create_engine, Table, MetaData, update, select
|
8 |
from sqlalchemy.orm import sessionmaker
|
9 |
from sqlalchemy import inspect
|
|
|
3 |
'''
|
4 |
import boto3
|
5 |
import os, time
|
6 |
+
from wrapper import generator_wrapper
|
7 |
from sqlalchemy import create_engine, Table, MetaData, update, select
|
8 |
from sqlalchemy.orm import sessionmaker
|
9 |
from sqlalchemy import inspect
|
wrapper.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script is used to wrap all generation methods together.
|
3 |
+
|
4 |
+
todo:
|
5 |
+
A worker keeps running on the server. Monitor the Amazon SQS. Once receive a new message, do the following:
|
6 |
+
Download the corresponding configuration files on S3.
|
7 |
+
Change Task status from Pending to Running.
|
8 |
+
Call `generator_wrapper` and wait for the outputs.
|
9 |
+
If `generator_wrapper` returns results:
|
10 |
+
evaluate the results; compile it; upload results to S3 ... Change Task status from Running to Completed.
|
11 |
+
If anything goes wrong, raise Error.
|
12 |
+
If `generator_wrapper` returns nothing or Timeout, or raise any error:
|
13 |
+
Change Task status from Running to Failed.
|
14 |
+
"""
|
15 |
+
from auto_generators import generate_draft
|
16 |
+
from utils.file_operations import make_archive
|
17 |
+
import yaml
|
18 |
+
import uuid
|
19 |
+
|
20 |
+
|
21 |
+
def remove_special_characters(s):
|
22 |
+
return ''.join(c for c in s if c.isalnum() or c.isspace() or c == ',')
|
23 |
+
|
24 |
+
|
25 |
+
def generator_wrapper(config):
|
26 |
+
if not isinstance(config, dict):
|
27 |
+
with open(config, "r") as file:
|
28 |
+
config = yaml.safe_load(file)
|
29 |
+
title = config["paper"]["title"]
|
30 |
+
generator = config["generator"]
|
31 |
+
if generator == "auto_draft":
|
32 |
+
folder = generate_draft(title, config["paper"]["description"],
|
33 |
+
tldr=config["references"]["tldr"],
|
34 |
+
max_kw_refs=config["references"]["max_kw_refs"],
|
35 |
+
refs=config["references"]["refs"],
|
36 |
+
max_tokens_ref=config["references"]["max_tokens_ref"],
|
37 |
+
knowledge_database=config["domain_knowledge"]["knowledge_database"],
|
38 |
+
max_tokens_kd=config["domain_knowledge"]["max_tokens_kd"],
|
39 |
+
query_counts=config["domain_knowledge"]["query_counts"],
|
40 |
+
sections=config["output"]["selected_sections"],
|
41 |
+
model=config["output"]["model"],
|
42 |
+
template=config["output"]["template"],
|
43 |
+
prompts_mode=config["output"]["prompts_mode"],
|
44 |
+
)
|
45 |
+
else:
|
46 |
+
raise NotImplementedError(f"The generator {generator} has not been supported yet.")
|
47 |
+
# todo: post processing: translate to Chinese, compile PDF ...
|
48 |
+
filename = remove_special_characters(title).replace(" ", "_") + uuid.uuid1().hex + ".zip"
|
49 |
+
return make_archive(folder, filename)
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
pass
|
54 |
+
# with open("configurations/default.yaml", 'r') as file:
|
55 |
+
# config = yaml.safe_load(file)
|
56 |
+
# print(config)
|
57 |
+
# generator_wrapper(config)
|