timeki commited on
Commit
76603df
·
1 Parent(s): a059c93

put event handling in separate file

Browse files
Files changed (3) hide show
  1. app.py +22 -131
  2. climateqa/event_handler.py +120 -0
  3. front/utils.py +42 -1
app.py CHANGED
@@ -27,12 +27,11 @@ from azure.storage.fileshare import ShareServiceClient
27
 
28
  from utils import create_user_id
29
 
30
- from langchain_chroma import Chroma
31
- from collections import defaultdict
32
  from gradio_modal import Modal
33
 
34
  from PIL import Image
35
 
 
36
 
37
  # ClimateQ&A imports
38
  from climateqa.engine.llm import get_llm
@@ -49,9 +48,9 @@ from climateqa.engine.keywords import make_keywords_chain
49
  from climateqa.engine.graph import make_graph_agent,display_graph
50
  from climateqa.engine.embeddings import get_embeddings_function
51
 
52
- from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
53
 
54
- from front.utils import make_html_source, make_html_figure_sources,parse_output_llm_with_sources,serialize_docs,make_toolbox
55
 
56
  # Load environment variables in local mode
57
  try:
@@ -121,6 +120,7 @@ reranker = get_reranker("nano")
121
  # agent = make_graph_agent(llm,vectorstore,reranker)
122
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
123
 
 
124
  async def chat(query,history,audience,sources,reports,current_graphs):
125
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
126
  (messages in gradio format, messages in langchain format, source documents)"""
@@ -128,14 +128,7 @@ async def chat(query,history,audience,sources,reports,current_graphs):
128
  date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
129
  print(f">> NEW QUESTION ({date_now}) : {query}")
130
 
131
- if audience == "Children":
132
- audience_prompt = audience_prompts["children"]
133
- elif audience == "General public":
134
- audience_prompt = audience_prompts["general"]
135
- elif audience == "Experts":
136
- audience_prompt = audience_prompts["experts"]
137
- else:
138
- audience_prompt = audience_prompts["experts"]
139
 
140
  # Prepare default values
141
  if sources is None or len(sources) == 0:
@@ -149,14 +142,11 @@ async def chat(query,history,audience,sources,reports,current_graphs):
149
 
150
 
151
  docs = []
152
- docs_used = True
153
  used_figures=[]
154
  docs_html = ""
155
  output_query = ""
156
  output_language = ""
157
  output_keywords = ""
158
- gallery = []
159
- updates = []
160
  start_streaming = False
161
  graphs_html = ""
162
  figures = '<div class="figures-container"><p></p> </div>'
@@ -175,79 +165,19 @@ async def chat(query,history,audience,sources,reports,current_graphs):
175
  node = event["metadata"]["langgraph_node"]
176
 
177
  if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
178
- try:
179
- docs = event["data"]["output"]["documents"]
180
- docs_html = []
181
- textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
182
- for i, d in enumerate(textual_docs, 1):
183
- if d.metadata["chunk_type"] == "text":
184
- docs_html.append(make_html_source(d, i))
185
-
186
- used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
187
- history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
188
-
189
- docs_html = "".join(docs_html)
190
-
191
- except Exception as e:
192
- print(f"Error getting documents: {e}")
193
- print(event)
194
-
195
 
196
 
197
  elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
198
- event_description,display_output = steps_display[node]
199
  if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
200
  history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description}))
201
 
202
  elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer
203
- if start_streaming == False:
204
- start_streaming = True
205
- history.append(ChatMessage(role="assistant", content = ""))
206
- answer_message_content += event["data"]["chunk"].content
207
- answer_message_content = parse_output_llm_with_sources(answer_message_content)
208
- history[-1] = ChatMessage(role="assistant", content = answer_message_content)
209
- # history.append(ChatMessage(role="assistant", content = new_message_content))
210
 
211
  elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
212
- try:
213
- recommended_content = event["data"]["output"]["recommended_content"]
214
-
215
- unique_graphs = []
216
- seen_embeddings = set()
217
-
218
- for x in recommended_content:
219
- embedding = x.metadata["returned_content"]
220
-
221
- # Check if the embedding has already been seen
222
- if embedding not in seen_embeddings:
223
- unique_graphs.append({
224
- "embedding": embedding,
225
- "metadata": {
226
- "source": x.metadata["source"],
227
- "category": x.metadata["category"]
228
- }
229
- })
230
- # Add the embedding to the seen set
231
- seen_embeddings.add(embedding)
232
-
233
-
234
- categories = {}
235
- for graph in unique_graphs:
236
- category = graph['metadata']['category']
237
- if category not in categories:
238
- categories[category] = []
239
- categories[category].append(graph['embedding'])
240
-
241
-
242
- for category, embeddings in categories.items():
243
- graphs_html += f"<h3>{category}</h3>"
244
- for embedding in embeddings:
245
- graphs_html += f"<div>{embedding}</div>"
246
-
247
-
248
- except Exception as e:
249
- print(f"Error getting graphs: {e}")
250
-
251
 
252
 
253
  if event["name"] == "transform_query" and event["event"] =="on_chain_end":
@@ -257,7 +187,7 @@ async def chat(query,history,audience,sources,reports,current_graphs):
257
  if event["name"] == "categorize_intent" and event["event"] == "on_chain_start":
258
  print("X")
259
 
260
- yield history, docs_html, output_query, output_language, docs , graphs_html#gallery, figures, #,output_query,output_keywords
261
 
262
  except Exception as e:
263
  print(event, "has failed")
@@ -285,52 +215,9 @@ async def chat(query,history,audience,sources,reports,current_graphs):
285
  print(f"Error logging on Azure Blob Storage: {e}")
286
  raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
287
 
288
-
289
 
290
 
291
-
292
-
293
-
294
-
295
- yield history, docs_html, output_query, output_language, docs, graphs_html # gallery, figures, graphs_html#,output_query,output_keywords
296
-
297
- # def process_figures(docs, figures, gallery, used_figures =[]):
298
- def process_figures(docs):
299
- gallery=[]
300
- used_figures =[]
301
- figures = '<div class="figures-container"><p></p> </div>'
302
- docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
303
- for i, doc in enumerate(docs_figures):
304
- if doc.metadata["chunk_type"] == "image":
305
- if doc.metadata["figure_code"] != "N/A":
306
- title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
307
- else:
308
- title = f"{doc.metadata['short_name']}"
309
-
310
-
311
- if title not in used_figures:
312
- used_figures.append(title)
313
- try:
314
- key = f"Image {i+1}"
315
-
316
- image_path = doc.metadata["image_path"].split("documents/")[1]
317
- img = get_image_from_azure_blob_storage(image_path)
318
-
319
- # Convert the image to a byte buffer
320
- buffered = BytesIO()
321
- max_image_length = 500
322
- img_resized = img.resize((max_image_length, int(max_image_length * img.size[1]/img.size[0])))
323
- img_resized.save(buffered, format="PNG")
324
-
325
- img_str = base64.b64encode(buffered.getvalue()).decode()
326
-
327
- figures = figures + make_html_figure_sources(doc, i, img_str)
328
- gallery.append(img)
329
- except Exception as e:
330
- print(f"Skipped adding image {i} because of {e}")
331
-
332
- return figures, gallery
333
-
334
  def save_feedback(feed: str, user_id):
335
  if len(feed) > 1:
336
  timestamp = str(datetime.now().timestamp())
@@ -657,13 +544,15 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
657
 
658
 
659
 
660
- gr.Markdown("""
661
- ### More info
662
- - See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
663
- - Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)
664
-
665
- ### Citation
666
- """)
 
 
667
  with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,):
668
  # # Display citation label and text)
669
  gr.Textbox(
@@ -721,6 +610,8 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
721
 
722
 
723
  sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])
 
 
724
  sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs],[tab_sources, tab_figures, tab_recommended_content])
725
  figures_cards.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs],[tab_sources, tab_figures, tab_recommended_content])
726
  current_graphs.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs],[tab_sources, tab_figures, tab_recommended_content])
 
27
 
28
  from utils import create_user_id
29
 
 
 
30
  from gradio_modal import Modal
31
 
32
  from PIL import Image
33
 
34
+ from langchain_core.runnables.schema import StreamEvent
35
 
36
  # ClimateQ&A imports
37
  from climateqa.engine.llm import get_llm
 
48
  from climateqa.engine.graph import make_graph_agent,display_graph
49
  from climateqa.engine.embeddings import get_embeddings_function
50
 
51
+ from front.utils import serialize_docs,process_figures
52
 
53
+ from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
54
 
55
  # Load environment variables in local mode
56
  try:
 
120
  # agent = make_graph_agent(llm,vectorstore,reranker)
121
  agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
122
 
123
+
124
  async def chat(query,history,audience,sources,reports,current_graphs):
125
  """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
126
  (messages in gradio format, messages in langchain format, source documents)"""
 
128
  date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
129
  print(f">> NEW QUESTION ({date_now}) : {query}")
130
 
131
+ audience_prompt = init_audience(audience)
 
 
 
 
 
 
 
132
 
133
  # Prepare default values
134
  if sources is None or len(sources) == 0:
 
142
 
143
 
144
  docs = []
 
145
  used_figures=[]
146
  docs_html = ""
147
  output_query = ""
148
  output_language = ""
149
  output_keywords = ""
 
 
150
  start_streaming = False
151
  graphs_html = ""
152
  figures = '<div class="figures-container"><p></p> </div>'
 
165
  node = event["metadata"]["langgraph_node"]
166
 
167
  if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
168
+ docs, docs_html, history, used_documents = handle_retrieved_documents(event, history, used_documents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
172
+ event_description, display_output = steps_display[node]
173
  if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
174
  history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description}))
175
 
176
  elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer
177
+ history, start_streaming, answer_message_content = stream_answer(history, event, start_streaming, answer_message_content)
 
 
 
 
 
 
178
 
179
  elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
180
+ graphs_html = handle_retrieved_owid_graphs(event, graphs_html)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
 
183
  if event["name"] == "transform_query" and event["event"] =="on_chain_end":
 
187
  if event["name"] == "categorize_intent" and event["event"] == "on_chain_start":
188
  print("X")
189
 
190
+ yield history, docs_html, output_query, output_language, docs , graphs_html #,output_query,output_keywords
191
 
192
  except Exception as e:
193
  print(event, "has failed")
 
215
  print(f"Error logging on Azure Blob Storage: {e}")
216
  raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
217
 
218
+ yield history, docs_html, output_query, output_language, docs, graphs_html
219
 
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  def save_feedback(feed: str, user_id):
222
  if len(feed) > 1:
223
  timestamp = str(datetime.now().timestamp())
 
544
 
545
 
546
 
547
+ gr.Markdown(
548
+ """
549
+ ### More info
550
+ - See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
551
+ - Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)
552
+
553
+ ### Citation
554
+ """
555
+ )
556
  with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,):
557
  # # Display citation label and text)
558
  gr.Textbox(
 
610
 
611
 
612
  sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])
613
+
614
+
615
  sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs],[tab_sources, tab_figures, tab_recommended_content])
616
  figures_cards.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs],[tab_sources, tab_figures, tab_recommended_content])
617
  current_graphs.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs],[tab_sources, tab_figures, tab_recommended_content])
climateqa/event_handler.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.runnables.schema import StreamEvent
2
+ from gradio import ChatMessage
3
+ from climateqa.engine.chains.prompts import audience_prompts
4
+ from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
5
+ import numpy as np
6
+
7
+ def init_audience(audience :str) -> str:
8
+ if audience == "Children":
9
+ audience_prompt = audience_prompts["children"]
10
+ elif audience == "General public":
11
+ audience_prompt = audience_prompts["general"]
12
+ elif audience == "Experts":
13
+ audience_prompt = audience_prompts["experts"]
14
+ else:
15
+ audience_prompt = audience_prompts["experts"]
16
+ return audience_prompt
17
+
18
+ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[str, list[ChatMessage], list[str]]:
19
+ """
20
+ Handles the retrieved documents and returns the HTML representation of the documents
21
+
22
+ Args:
23
+ event (StreamEvent): The event containing the retrieved documents
24
+ history (list[ChatMessage]): The current message history
25
+ used_documents (list[str]): The list of used documents
26
+
27
+ Returns:
28
+ tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
29
+ """
30
+ try:
31
+ docs = event["data"]["output"]["documents"]
32
+ docs_html = []
33
+ textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
34
+ for i, d in enumerate(textual_docs, 1):
35
+ if d.metadata["chunk_type"] == "text":
36
+ docs_html.append(make_html_source(d, i))
37
+
38
+ used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
39
+ history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
40
+
41
+ docs_html = "".join(docs_html)
42
+
43
+ except Exception as e:
44
+ print(f"Error getting documents: {e}")
45
+ print(event)
46
+ return docs, docs_html, history, used_documents
47
+
48
+ def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
49
+ """
50
+ Handles the streaming of the answer and updates the history with the new message content
51
+
52
+ Args:
53
+ history (list[ChatMessage]): The current message history
54
+ event (StreamEvent): The event containing the streamed answer
55
+ start_streaming (bool): A flag indicating if the streaming has started
56
+ new_message_content (str): The content of the new message
57
+
58
+ Returns:
59
+ tuple[list[ChatMessage], bool, str]: The updated history, the updated streaming flag and the updated message content
60
+ """
61
+ if start_streaming == False:
62
+ start_streaming = True
63
+ history.append(ChatMessage(role="assistant", content = ""))
64
+ answer_message_content += event["data"]["chunk"].content
65
+ answer_message_content = parse_output_llm_with_sources(answer_message_content)
66
+ history[-1] = ChatMessage(role="assistant", content = answer_message_content)
67
+ # history.append(ChatMessage(role="assistant", content = new_message_content))
68
+ return history, start_streaming, answer_message_content
69
+
70
+ def handle_retrieved_owid_graphs(event :StreamEvent, graphs_html: str) -> str:
71
+ """
72
+ Handles the retrieved OWID graphs and returns the HTML representation of the graphs
73
+
74
+ Args:
75
+ event (StreamEvent): The event containing the retrieved graphs
76
+ graphs_html (str): The current HTML representation of the graphs
77
+
78
+ Returns:
79
+ str: The updated HTML representation
80
+ """
81
+ try:
82
+ recommended_content = event["data"]["output"]["recommended_content"]
83
+
84
+ unique_graphs = []
85
+ seen_embeddings = set()
86
+
87
+ for x in recommended_content:
88
+ embedding = x.metadata["returned_content"]
89
+
90
+ # Check if the embedding has already been seen
91
+ if embedding not in seen_embeddings:
92
+ unique_graphs.append({
93
+ "embedding": embedding,
94
+ "metadata": {
95
+ "source": x.metadata["source"],
96
+ "category": x.metadata["category"]
97
+ }
98
+ })
99
+ # Add the embedding to the seen set
100
+ seen_embeddings.add(embedding)
101
+
102
+
103
+ categories = {}
104
+ for graph in unique_graphs:
105
+ category = graph['metadata']['category']
106
+ if category not in categories:
107
+ categories[category] = []
108
+ categories[category].append(graph['embedding'])
109
+
110
+
111
+ for category, embeddings in categories.items():
112
+ graphs_html += f"<h3>{category}</h3>"
113
+ for embedding in embeddings:
114
+ graphs_html += f"<div>{embedding}</div>"
115
+
116
+
117
+ except Exception as e:
118
+ print(f"Error getting graphs: {e}")
119
+
120
+ return graphs_html
front/utils.py CHANGED
@@ -1,5 +1,12 @@
1
 
2
  import re
 
 
 
 
 
 
 
3
 
4
  def make_pairs(lst):
5
  """from a list of even lenght, make tupple pairs"""
@@ -32,8 +39,42 @@ def parse_output_llm_with_sources(output):
32
  content_parts = "".join(parts)
33
  return content_parts
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- from collections import defaultdict
37
 
38
  def generate_html_graphs(graphs):
39
  # Organize graphs by category
 
1
 
2
  import re
3
+ from collections import defaultdict
4
+ from climateqa.utils import get_image_from_azure_blob_storage
5
+ from climateqa.engine.chains.prompts import audience_prompts
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import base64
9
+
10
 
11
  def make_pairs(lst):
12
  """from a list of even lenght, make tupple pairs"""
 
39
  content_parts = "".join(parts)
40
  return content_parts
41
 
42
+ def process_figures(docs):
43
+ gallery=[]
44
+ used_figures =[]
45
+ figures = '<div class="figures-container"><p></p> </div>'
46
+ docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
47
+ for i, doc in enumerate(docs_figures):
48
+ if doc.metadata["chunk_type"] == "image":
49
+ if doc.metadata["figure_code"] != "N/A":
50
+ title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
51
+ else:
52
+ title = f"{doc.metadata['short_name']}"
53
+
54
+
55
+ if title not in used_figures:
56
+ used_figures.append(title)
57
+ try:
58
+ key = f"Image {i+1}"
59
+
60
+ image_path = doc.metadata["image_path"].split("documents/")[1]
61
+ img = get_image_from_azure_blob_storage(image_path)
62
+
63
+ # Convert the image to a byte buffer
64
+ buffered = BytesIO()
65
+ max_image_length = 500
66
+ img_resized = img.resize((max_image_length, int(max_image_length * img.size[1]/img.size[0])))
67
+ img_resized.save(buffered, format="PNG")
68
+
69
+ img_str = base64.b64encode(buffered.getvalue()).decode()
70
+
71
+ figures = figures + make_html_figure_sources(doc, i, img_str)
72
+ gallery.append(img)
73
+ except Exception as e:
74
+ print(f"Skipped adding image {i} because of {e}")
75
+
76
+ return figures, gallery
77
 
 
78
 
79
  def generate_html_graphs(graphs):
80
  # Organize graphs by category