Izza-shahzad-13 commited on
Commit
e00a19f
·
verified ·
1 Parent(s): 7b48254

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -13
app.py CHANGED
@@ -20,8 +20,21 @@ def retrieve_embedding(user_query):
20
  headers = {
21
  "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
22
  }
 
 
23
  response = requests.post(f"{GROQ_API_URL}/embedding", json=payload, headers=headers)
24
- return response.json()["embedding"]
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Function to perform response generation using FLAN-T5 via Groq API
27
  def generate_response(context):
@@ -32,8 +45,21 @@ def generate_response(context):
32
  headers = {
33
  "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
34
  }
 
 
35
  response = requests.post(f"{GROQ_API_URL}/generate", json=payload, headers=headers)
36
- return response.json()["text"]
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Load the counseling conversations dataset
39
  dataset = load_dataset("Amod/mental_health_counseling_conversations")["train"]
@@ -43,8 +69,9 @@ dataset = load_dataset("Amod/mental_health_counseling_conversations")["train"]
43
  def embed_dataset(_dataset):
44
  embeddings = []
45
  for entry in _dataset:
46
- embedding = retrieve_embedding(entry["Response"])
47
- embeddings.append(embedding)
 
48
  return embeddings
49
 
50
  dataset_embeddings = embed_dataset(dataset)
@@ -52,12 +79,16 @@ dataset_embeddings = embed_dataset(dataset)
52
  # Function to retrieve closest responses from the dataset using cosine similarity
53
  def retrieve_response(user_query, dataset, dataset_embeddings, k=5):
54
  query_embedding = retrieve_embedding(user_query)
 
 
 
 
55
  cos_scores = cosine_similarity([query_embedding], dataset_embeddings)[0]
56
  top_indices = np.argsort(cos_scores)[-k:][::-1]
57
 
58
  retrieved_responses = []
59
  for idx in top_indices:
60
- retrieved_responses.append(dataset[idx]["Response"])
61
  return retrieved_responses
62
 
63
  # Streamlit app UI
@@ -71,11 +102,17 @@ if user_query:
71
  # Retrieve similar responses from the dataset
72
  retrieved_responses = retrieve_response(user_query, dataset, dataset_embeddings)
73
 
74
- # Join retrieved responses to create a supportive context
75
- context = " ".join(retrieved_responses)
76
-
77
- # Generate a supportive response using FLAN-T5 via Groq API
78
- supportive_response = generate_response(context)
79
-
80
- st.write("Here's some advice or support for you:")
81
- st.write(supportive_response)
 
 
 
 
 
 
 
20
  headers = {
21
  "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
22
  }
23
+
24
+ # Make the API request
25
  response = requests.post(f"{GROQ_API_URL}/embedding", json=payload, headers=headers)
26
+
27
+ # Check for errors and return the embedding if available
28
+ if response.status_code == 200:
29
+ json_response = response.json()
30
+ if "embedding" in json_response:
31
+ return json_response["embedding"]
32
+ else:
33
+ st.error("The response from the API did not contain an embedding. Please check the API.")
34
+ return None
35
+ else:
36
+ st.error(f"Failed to retrieve embedding. Status code: {response.status_code}")
37
+ return None
38
 
39
  # Function to perform response generation using FLAN-T5 via Groq API
40
  def generate_response(context):
 
45
  headers = {
46
  "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
47
  }
48
+
49
+ # Make the API request
50
  response = requests.post(f"{GROQ_API_URL}/generate", json=payload, headers=headers)
51
+
52
+ # Check for errors and return the response text if available
53
+ if response.status_code == 200:
54
+ json_response = response.json()
55
+ if "text" in json_response:
56
+ return json_response["text"]
57
+ else:
58
+ st.error("The response from the API did not contain a 'text' key.")
59
+ return None
60
+ else:
61
+ st.error(f"Failed to generate response. Status code: {response.status_code}")
62
+ return None
63
 
64
  # Load the counseling conversations dataset
65
  dataset = load_dataset("Amod/mental_health_counseling_conversations")["train"]
 
69
  def embed_dataset(_dataset):
70
  embeddings = []
71
  for entry in _dataset:
72
+ embedding = retrieve_embedding(entry["response"])
73
+ if embedding is not None:
74
+ embeddings.append(embedding)
75
  return embeddings
76
 
77
  dataset_embeddings = embed_dataset(dataset)
 
79
  # Function to retrieve closest responses from the dataset using cosine similarity
80
  def retrieve_response(user_query, dataset, dataset_embeddings, k=5):
81
  query_embedding = retrieve_embedding(user_query)
82
+ if query_embedding is None:
83
+ st.error("Could not retrieve an embedding for the query.")
84
+ return []
85
+
86
  cos_scores = cosine_similarity([query_embedding], dataset_embeddings)[0]
87
  top_indices = np.argsort(cos_scores)[-k:][::-1]
88
 
89
  retrieved_responses = []
90
  for idx in top_indices:
91
+ retrieved_responses.append(dataset[idx]["response"])
92
  return retrieved_responses
93
 
94
  # Streamlit app UI
 
102
  # Retrieve similar responses from the dataset
103
  retrieved_responses = retrieve_response(user_query, dataset, dataset_embeddings)
104
 
105
+ if retrieved_responses:
106
+ # Join retrieved responses to create a supportive context
107
+ context = " ".join(retrieved_responses)
108
+
109
+ # Generate a supportive response using FLAN-T5 via Groq API
110
+ supportive_response = generate_response(context)
111
+
112
+ if supportive_response:
113
+ st.write("Here's some advice or support for you:")
114
+ st.write(supportive_response)
115
+ else:
116
+ st.write("Sorry, I couldn't generate a response at the moment.")
117
+ else:
118
+ st.write("Sorry, I couldn't find any relevant responses.")