Update app.py
app.py CHANGED
@@ -10,8 +10,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
 sent = "dnzblgn/Sentiment-Analysis-Customer-Reviews"
 sarc = "dnzblgn/Sarcasm-Detection-Customer-Reviews"
 doc = "dnzblgn/Customer-Reviews-Classification"
-embedding_model = SentenceTransformer('
-
+embedding_model = SentenceTransformer('all-MiniLM-L6-v2')  # Lightweight embedding model for CPU

 # Your models (no token, no fast tokenizer)
 sentiment_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Sentiment-Analysis-Customer-Reviews", use_fast=False)
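The comment on the new line gives the motivation: `all-MiniLM-L6-v2` is small enough to embed documents on free CPU hardware. The commit never shows the index that `query_chatbot` later searches (the `indices[0]` access in a later hunk suggests a FAISS-style lookup), so here is a minimal sketch of how the swapped-in model would typically feed one; `faiss`, `build_index`, and `documents` are assumptions for illustration, not code from this commit:

```python
from sentence_transformers import SentenceTransformer
import faiss  # assumed dependency; the diff only shows search results being consumed

embedding_model = SentenceTransformer('all-MiniLM-L6-v2')  # 384-dim vectors, CPU-friendly

def build_index(documents):
    # Embed every document once and index the vectors for exact L2 search
    embeddings = embedding_model.encode(documents).astype("float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
```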
@@ -23,13 +22,9 @@ sarcasm_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Sarcasm-Detection-Customer-Reviews")
 classification_tokenizer = AutoTokenizer.from_pretrained("dnzblgn/Customer-Reviews-Classification", use_fast=False)
 classification_model = AutoModelForSequenceClassification.from_pretrained("dnzblgn/Customer-Reviews-Classification")

-#
-HF_TOKEN = os.getenv("rag")
-if not HF_TOKEN:
-    raise ValueError("Environment variable 'rag' is not set. Please check your Space secrets.")
-
-mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False, use_auth_token=HF_TOKEN)
-mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", use_auth_token=HF_TOKEN).eval()
+# Lightweight Causal Language Model (distilgpt2 instead of Mistral)
+causal_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+causal_model = AutoModelForCausalLM.from_pretrained("distilgpt2").eval()  # Ensure evaluation mode

 # Paths and files
 UPLOAD_FOLDER = "uploads"
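The removed block loaded the gated `mistralai/Mistral-7B-v0.1` behind an `HF_TOKEN` read from the Space secret `rag`, raising if the secret was missing. Dropping it removes both the token dependency and roughly 7 billion parameters of memory pressure, a plausible cause of the Space's runtime error. A quick sanity check of the size gap (a sketch; the Mistral figure is quoted, not computed here, since loading it needs on the order of 14 GB):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("distilgpt2")
n_params = sum(p.numel() for p in model.parameters())
print(f"distilgpt2: {n_params / 1e6:.0f}M parameters")  # ~82M, vs ~7,200M for Mistral-7B-v0.1
```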
@@ -133,11 +128,11 @@ def handle_uploaded_file(file):

     return "File uploaded and processed successfully."

-def
-    inputs =
+def causal_generate_response(prompt):
+    inputs = causal_tokenizer(prompt, return_tensors="pt")  # Default CPU
     with torch.no_grad():
-        outputs =
-        response =
+        outputs = causal_model.generate(inputs["input_ids"], max_length=500, do_sample=True, temperature=0.7)
+    response = causal_tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response

 def query_chatbot(query):
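Two caveats in the helper as committed: GPT-2-family models define no pad token, so `generate` warns unless one is supplied, and `max_length=500` counts the prompt tokens too, so a long RAG context can leave little or no budget for the answer. A hedged variant fixing both (it reuses the commit's `causal_tokenizer`/`causal_model`; the 200-token budget is an arbitrary choice, not from the commit):

```python
import torch

def causal_generate_response(prompt):
    # Clip the prompt to distilgpt2's 1024-token context window
    inputs = causal_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        outputs = causal_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=200,                          # budget for the answer only
            do_sample=True,
            temperature=0.7,
            pad_token_id=causal_tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
        )
    return causal_tokenizer.decode(outputs[0], skip_special_tokens=True)
```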
@@ -153,13 +148,18 @@ def query_chatbot(query):
     relevant_docs = [documents[idx] for idx in indices[0] if idx < len(documents)]
     context = "\n\n".join(relevant_docs[:top_k])

+    # Custom Prompt for RAG
     final_prompt = (
+        f"You are a business data analyst. Analyze the feedback data and identify the overall sentiment trends. "
+        f"Focus on determining whether positive feedback or negative feedback dominates in each category, and avoid overstating less significant trends. "
+        f"Provide clear, data-driven insights.\n\n"
         f"Context:\n{context}\n\n"
         f"Question: {query}\n\n"
-        f"Your Answer (based on the context):"
+        f"Your Answer (based on the data and context):"
     )

-    return
+    return causal_generate_response(final_prompt)
+

 # Gradio interface
 with gr.Blocks() as interface:
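One thing to keep in mind about the new prompt: `distilgpt2` is a base language model, not instruction-tuned, so it simply continues the text, and the decoded output starts with the prompt itself. If the chatbot should return only the model's continuation, post-processing along these lines would be needed (hypothetical, not in the commit; slicing by string length is approximate because re-tokenization can alter whitespace):

```python
response = causal_generate_response(final_prompt)
answer = response[len(final_prompt):].strip()  # generate() echoes the prompt; keep only the continuation
```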
@@ -190,4 +190,4 @@ with gr.Blocks() as interface:

 # Run Gradio app
 if __name__ == "__main__":
-    interface.launch()
+    interface.launch()
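On Spaces a bare `interface.launch()` is enough, since the platform supplies host and port. Running the same app outside Spaces (e.g. in a plain container) usually means pinning them explicitly; a hypothetical variant:

```python
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860)  # bind explicitly when not on Spaces
```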