dnzblgn committed on
Commit 21a7915 · verified · 1 Parent(s): e391337

Update app.py
Files changed (1): app.py (+16 -16)
app.py CHANGED
@@ -10,8 +10,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
 sent = "dnzblgn/Sentiment-Analysis-Customer-Reviews"
 sarc = "dnzblgn/Sarcasm-Detection-Customer-Reviews"
 doc = "dnzblgn/Customer-Reviews-Classification"
-embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
-
+embedding_model = SentenceTransformer('all-MiniLM-L6-v2')  # Lightweight embedding model for CPU
 
 # Your models (no token, no fast tokenizer)
 sentiment_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Sentiment-Analysis-Customer-Reviews", use_fast=False)
@@ -23,13 +22,9 @@ sarcasm_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Sarcasm-Detection-Customer-Reviews")
 classification_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Customer-Reviews-Classification", use_fast=False)
 classification_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Customer-Reviews-Classification")
 
-# Mistral model (requires token, must be authenticated)
-HF_TOKEN = os.getenv("rag")  # Using the secret from your Hugging Face Space
-if HF_TOKEN is None:
-    raise ValueError("Environment variable 'rag' is not set. Please check your Space secrets.")
-
-mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False, use_auth_token=HF_TOKEN)
-mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", use_auth_token=HF_TOKEN).eval()
+# Lightweight Causal Language Model (distilgpt2 instead of Mistral)
+causal_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+causal_model = AutoModelForCausalLM.from_pretrained("distilgpt2").eval()  # Ensure evaluation mode
 
 # Paths and files
 UPLOAD_FOLDER = "uploads"
@@ -133,11 +128,11 @@ def handle_uploaded_file(file):
 
     return "File uploaded and processed successfully."
 
-def mistral_generate_response(prompt):
-    inputs = mistral_tokenizer(prompt, return_tensors="pt")  # Default is CPU
+def causal_generate_response(prompt):
+    inputs = causal_tokenizer(prompt, return_tensors="pt")  # Default CPU
     with torch.no_grad():
-        outputs = mistral_model.generate(inputs["input_ids"], max_length=500, do_sample=True, temperature=0.7)
-    response = mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
+        outputs = causal_model.generate(inputs["input_ids"], max_length=500, do_sample=True, temperature=0.7)
+    response = causal_tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
 def query_chatbot(query):
@@ -153,13 +148,18 @@ def query_chatbot(query):
     relevant_docs = [documents[idx] for idx in indices[0] if idx < len(documents)]
     context = "\n\n".join(relevant_docs[:top_k])
 
+    # Custom Prompt for RAG
     final_prompt = (
+        f"You are a business data analyst. Analyze the feedback data and identify the overall sentiment trends. "
+        f"Focus on determining whether positive feedback or negative feedback dominates in each category, and avoid overstating less significant trends. "
+        f"Provide clear, data-driven insights.\n\n"
         f"Context:\n{context}\n\n"
         f"Question: {query}\n\n"
-        f"Your Answer (based on the context):"
+        f"Your Answer (based on the data and context):"
    )
 
-    return mistral_generate_response(final_prompt)
+    return causal_generate_response(final_prompt)
+
 
 # Gradio interface
 with gr.Blocks() as interface:
@@ -190,4 +190,4 @@ with gr.Blocks() as interface:
 
 # Run Gradio app
 if __name__ == "__main__":
-    interface.launch()
+    interface.launch()
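
For context, here is a minimal standalone sketch of the retrieval-plus-generation path this commit moves to: the all-MiniLM-L6-v2 embedder plus the distilgpt2 generator, wired together the same way app.py's query_chatbot does. The toy documents corpus, the FAISS index construction, and the sample query are illustrative stand-ins; the real app builds its index from uploaded files in code not shown in this diff.

# Standalone sketch; names mirror app.py, toy data is hypothetical.
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")   # embedder from this commit
causal_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
causal_model = AutoModelForCausalLM.from_pretrained("distilgpt2").eval()

# Toy corpus standing in for the uploaded customer reviews.
documents = [
    "Delivery was fast and the packaging was intact.",
    "Battery life is disappointing; it barely lasts a day.",
]
embeddings = embedding_model.encode(documents)              # float32 array, shape (n, d)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

def causal_generate_response(prompt):
    inputs = causal_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # Note: max_length counts the prompt tokens too, so a long retrieved
        # context leaves little room for the answer (max_new_tokens avoids this).
        outputs = causal_model.generate(inputs["input_ids"], max_length=500,
                                        do_sample=True, temperature=0.7)
    return causal_tokenizer.decode(outputs[0], skip_special_tokens=True)

query = "What do customers say about delivery?"
top_k = 2
_, indices = index.search(embedding_model.encode([query]), top_k)
context = "\n\n".join(documents[i] for i in indices[0] if i < len(documents))
print(causal_generate_response(
    f"Context:\n{context}\n\nQuestion: {query}\n\nYour Answer (based on the data and context):"
))

One trade-off worth noting: distilgpt2's context window is 1024 tokens, far smaller than Mistral-7B's, so the retrieved context plus prompt must stay short; in exchange, both replacement models run without authentication and fit comfortably on a CPU-only Space.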