timeki commited on
Commit
6541df3
·
2 Parent(s): 9609df9 c3b815e

Merge branch 'add_openalex_papers' into feature/graph_recommandation

Browse files
app.py CHANGED
@@ -8,6 +8,7 @@ from sentence_transformers import CrossEncoder
8
  oa = OpenAlex()
9
 
10
  import gradio as gr
 
11
  import pandas as pd
12
  import numpy as np
13
  import os
@@ -44,11 +45,11 @@ from climateqa.sample_questions import QUESTIONS
44
  from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
45
  from climateqa.utils import get_image_from_azure_blob_storage
46
  from climateqa.engine.keywords import make_keywords_chain
47
- # from climateqa.engine.chains.answer_rag import make_rag_papers_chain
48
- from climateqa.engine.graph import make_graph_agent,display_graph
49
  from climateqa.engine.embeddings import get_embeddings_function
50
 
51
- from front.utils import serialize_docs,process_figures
52
 
53
  from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
54
 
@@ -115,9 +116,7 @@ vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name =
115
 
116
  llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
117
  reranker = get_reranker("nano")
118
- # agent = make_graph_agent(llm,vectorstore,reranker)
119
 
120
- # agent = make_graph_agent(llm,vectorstore,reranker)
121
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
122
 
123
 
@@ -248,13 +247,11 @@ def generate_keywords(query):
248
 
249
 
250
  papers_cols_widths = {
251
- "doc":50,
252
  "id":100,
253
  "title":300,
254
  "doi":100,
255
  "publication_year":100,
256
  "abstract":500,
257
- "rerank_score":100,
258
  "is_oa":50,
259
  }
260
 
@@ -262,6 +259,62 @@ papers_cols = list(papers_cols_widths.keys())
262
  papers_cols_widths = list(papers_cols_widths.values())
263
 
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  # --------------------------------------------------------------------
266
  # Gradio
267
  # --------------------------------------------------------------------
@@ -363,7 +416,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
363
  samples.append(group_examples)
364
 
365
 
366
- with gr.Tab("Sources",elem_id = "tab-citations",id = 1) as tab_sources:
367
  sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
368
  docs_textbox = gr.State("")
369
 
@@ -379,7 +432,28 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
379
  show_full_size_figures.click(lambda : Modal(visible=True),None,modal)
380
 
381
  figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
382
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=4) as tab_recommended_content:
385
  graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
@@ -511,6 +585,38 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
511
  # with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
512
  # gallery_component = gr.Gallery(object_fit='cover')
513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  # with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
515
 
516
  # with gr.Row():
@@ -571,6 +677,21 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
571
 
572
  def finish_chat():
573
  return gr.update(interactive = True,value = "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
 
575
 
576
  def change_completion_status(current_state):
@@ -618,6 +739,11 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
618
 
619
  dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
620
 
 
 
 
 
 
621
 
622
  demo.queue()
623
 
 
8
  oa = OpenAlex()
9
 
10
  import gradio as gr
11
+ from gradio_modal import Modal
12
  import pandas as pd
13
  import numpy as np
14
  import os
 
45
  from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
46
  from climateqa.utils import get_image_from_azure_blob_storage
47
  from climateqa.engine.keywords import make_keywords_chain
48
+ from climateqa.engine.chains.answer_rag import make_rag_papers_chain
49
+ from climateqa.engine.graph import make_graph_agent
50
  from climateqa.engine.embeddings import get_embeddings_function
51
 
52
+ from front.utils import serialize_docs,process_figures,make_html_df
53
 
54
  from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
55
 
 
116
 
117
  llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
118
  reranker = get_reranker("nano")
 
119
 
 
120
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
121
 
122
 
 
247
 
248
 
249
  papers_cols_widths = {
 
250
  "id":100,
251
  "title":300,
252
  "doi":100,
253
  "publication_year":100,
254
  "abstract":500,
 
255
  "is_oa":50,
256
  }
257
 
 
259
  papers_cols_widths = list(papers_cols_widths.values())
260
 
261
 
262
+ async def find_papers(query,after):
263
+
264
+ summary = ""
265
+ keywords = generate_keywords(query)
266
+ df_works = oa.search(keywords,after = after)
267
+ df_works = df_works.dropna(subset=["abstract"])
268
+ df_works = oa.rerank(query,df_works,reranker)
269
+ df_works = df_works.sort_values("rerank_score",ascending=False)
270
+ docs_html = []
271
+ for i in range(10):
272
+ docs_html.append(make_html_df(df_works, i))
273
+ docs_html = "".join(docs_html)
274
+ print(docs_html)
275
+ G = oa.make_network(df_works)
276
+
277
+ height = "750px"
278
+ network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
279
+ network_html = network.generate_html()
280
+
281
+ network_html = network_html.replace("'", "\"")
282
+ css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
283
+ network_html = network_html + css_to_inject
284
+
285
+
286
+ network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
287
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
288
+ allow-scripts allow-same-origin allow-popups
289
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
290
+ allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
291
+
292
+
293
+ docs = df_works["content"].head(10).tolist()
294
+
295
+ df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
296
+ df_works["doc"] = df_works["doc"] + 1
297
+ df_works = df_works[papers_cols]
298
+
299
+ yield docs_html, network_html, summary
300
+
301
+ chain = make_rag_papers_chain(llm)
302
+ result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
303
+ path_answer = "/logs/StrOutputParser/streamed_output/-"
304
+
305
+ async for op in result:
306
+
307
+ op = op.ops[0]
308
+
309
+ if op['path'] == path_answer: # reforulated question
310
+ new_token = op['value'] # str
311
+ summary += new_token
312
+ else:
313
+ continue
314
+ yield docs_html, network_html, summary
315
+
316
+
317
+
318
  # --------------------------------------------------------------------
319
  # Gradio
320
  # --------------------------------------------------------------------
 
416
  samples.append(group_examples)
417
 
418
 
419
+ with gr.Tab("Sources",elem_id = "tab-sources",id = 1):
420
  sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
421
  docs_textbox = gr.State("")
422
 
 
432
  show_full_size_figures.click(lambda : Modal(visible=True),None,modal)
433
 
434
  figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
435
+
436
+
437
+
438
+ with gr.Tab("Papers",elem_id = "tab-citations",id = 5):
439
+ btn_summary = gr.Button("Summary")
440
+ # Fenêtre simulée pour le Summary
441
+ with gr.Group(visible=False, elem_id="papers-summary-popup") as summary_popup:
442
+ papers_summary = gr.Markdown("### Summary Content", visible=True, elem_id="papers-summary")
443
+
444
+ btn_relevant_papers = gr.Button("Relevant papers")
445
+ # Fenêtre simulée pour les Relevant Papers
446
+ with gr.Group(visible=False, elem_id="papers-relevant-popup") as relevant_popup:
447
+ papers_html = gr.HTML(show_label=False, elem_id="sources-textbox")
448
+ docs_textbox = gr.State("")
449
+
450
+ btn_citations_network = gr.Button("Citations network")
451
+ # Fenêtre simulée pour le Citations Network
452
+ with Modal(visible=False) as modal:
453
+ citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
454
+ btn_citations_network.click(lambda: Modal(visible=True), None, modal)
455
+
456
+
457
 
458
  with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=4) as tab_recommended_content:
459
  graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
 
585
  # with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
586
  # gallery_component = gr.Gallery(object_fit='cover')
587
 
588
+ with gr.Tab("Settings",elem_id = "tab-config",id = 2):
589
+
590
+ gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
591
+
592
+
593
+ dropdown_sources = gr.CheckboxGroup(
594
+ ["IPCC", "IPBES","IPOS", "OpenAlex"],
595
+ label="Select source",
596
+ value=["IPCC"],
597
+ interactive=True,
598
+ )
599
+
600
+ dropdown_reports = gr.Dropdown(
601
+ POSSIBLE_REPORTS,
602
+ label="Or select specific reports",
603
+ multiselect=True,
604
+ value=None,
605
+ interactive=True,
606
+ )
607
+
608
+ dropdown_audience = gr.Dropdown(
609
+ ["Children","General public","Experts"],
610
+ label="Select audience",
611
+ value="Experts",
612
+ interactive=True,
613
+ )
614
+
615
+ after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
616
+
617
+ output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
618
+ output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
619
+
620
  # with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
621
 
622
  # with gr.Row():
 
677
 
678
  def finish_chat():
679
  return gr.update(interactive = True,value = "")
680
+
681
+ # Initialize visibility states
682
+ summary_visible = False
683
+ relevant_visible = False
684
+
685
+ # Functions to toggle visibility
686
+ def toggle_summary_visibility():
687
+ global summary_visible
688
+ summary_visible = not summary_visible
689
+ return gr.update(visible=summary_visible)
690
+
691
+ def toggle_relevant_visibility():
692
+ global relevant_visible
693
+ relevant_visible = not relevant_visible
694
+ return gr.update(visible=relevant_visible)
695
 
696
 
697
  def change_completion_status(current_state):
 
739
 
740
  dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
741
 
742
+ textbox.submit(find_papers,[textbox,after], [papers_html,citations_network,papers_summary])
743
+ examples_hidden.change(find_papers,[examples_hidden,after], [papers_html,citations_network,papers_summary])
744
+
745
+ btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
746
+ btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
747
 
748
  demo.queue()
749
 
climateqa/engine/chains/answer_rag.py CHANGED
@@ -7,6 +7,8 @@ from langchain_core.prompts.base import format_document
7
 
8
  from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
9
  from climateqa.engine.chains.prompts import papers_prompt_template
 
 
10
 
11
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
12
 
@@ -71,32 +73,32 @@ def make_rag_node(llm,with_docs = True):
71
 
72
 
73
 
74
- # def make_rag_papers_chain(llm):
75
 
76
- # prompt = ChatPromptTemplate.from_template(papers_prompt_template)
77
- # input_documents = {
78
- # "context":lambda x : _combine_documents(x["docs"]),
79
- # **pass_values(["question","language"])
80
- # }
81
 
82
- # chain = input_documents | prompt | llm | StrOutputParser()
83
- # chain = rename_chain(chain,"answer")
84
 
85
- # return chain
86
 
87
 
88
 
89
 
90
 
91
 
92
- # def make_illustration_chain(llm):
93
 
94
- # prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
95
 
96
- # input_description_images = {
97
- # "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
98
- # **pass_values(["question","audience","language","answer"]),
99
- # }
100
 
101
- # illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
102
- # return illustration_chain
 
7
 
8
  from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
9
  from climateqa.engine.chains.prompts import papers_prompt_template
10
+ from ..utils import rename_chain, pass_values
11
+
12
 
13
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
14
 
 
73
 
74
 
75
 
76
+ def make_rag_papers_chain(llm):
77
 
78
+ prompt = ChatPromptTemplate.from_template(papers_prompt_template)
79
+ input_documents = {
80
+ "context":lambda x : _combine_documents(x["docs"]),
81
+ **pass_values(["question","language"])
82
+ }
83
 
84
+ chain = input_documents | prompt | llm | StrOutputParser()
85
+ chain = rename_chain(chain,"answer")
86
 
87
+ return chain
88
 
89
 
90
 
91
 
92
 
93
 
94
+ def make_illustration_chain(llm):
95
 
96
+ prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
97
 
98
+ input_description_images = {
99
+ "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
100
+ **pass_values(["question","audience","language","answer"]),
101
+ }
102
 
103
+ illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
104
+ return illustration_chain
climateqa/knowledge/openalex.py CHANGED
@@ -62,11 +62,10 @@ class OpenAlex():
62
 
63
  scores = reranker.rank(
64
  query,
65
- df["content"].tolist(),
66
- top_k = len(df),
67
  )
68
- scores.sort(key = lambda x : x["corpus_id"])
69
- scores = [x["score"] for x in scores]
70
  df["rerank_score"] = scores
71
  return df
72
 
 
62
 
63
  scores = reranker.rank(
64
  query,
65
+ df["content"].tolist()
 
66
  )
67
+ scores = sorted(scores.results, key = lambda x : x.document.doc_id)
68
+ scores = [x.score for x in scores]
69
  df["rerank_score"] = scores
70
  return df
71
 
front/utils.py CHANGED
@@ -228,6 +228,28 @@ def make_html_source(source,i):
228
  return card
229
 
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  def make_html_figure_sources(source,i,img_str):
232
  meta = source.metadata
233
  content = source.page_content.strip()
 
228
  return card
229
 
230
 
231
+ def make_html_df(df,i):
232
+ title = df['title'][i]
233
+ content = df['abstract'][i]
234
+ url = df['doi'][i]
235
+ publication_date = df['publication_year'][i]
236
+
237
+ card = f"""
238
+ <div class="card" id="doc{i}">
239
+ <div class="card-content">
240
+ <h2>Doc {i+1} - {title}</h2>
241
+ <p>{content}</p>
242
+ </div>
243
+ <div class="card-footer">
244
+ <span>{publication_date}</span>
245
+ <a href="{url}" target="_blank" class="pdf-link">
246
+ </div>
247
+ </div>
248
+ """
249
+
250
+ return card
251
+
252
+
253
  def make_html_figure_sources(source,i,img_str):
254
  meta = source.metadata
255
  content = source.page_content.strip()