dnzblgn's picture
Update app.py
3a56bef verified
import gradio as gr
import torch
import faiss
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
# Model paths (Hugging Face Hub repository ids of the fine-tuned classifiers)
sent = "dnzblgn/Sentiment-Analysis-Customer-Reviews"
sarc = "dnzblgn/Sarcasm-Detection-Customer-Reviews"
doc = "dnzblgn/Customer-Reviews-Classification"

# Sentence-embedding model used to build and query the FAISS index.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Review classifiers, loaded from the repo ids above rather than repeating the
# literals (no auth token; slow tokenizers via use_fast=False).
sentiment_tokenizer = AutoTokenizer.from_pretrained(sent, use_fast=False)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(sent)
sarcasm_tokenizer = AutoTokenizer.from_pretrained(sarc, use_fast=False)
sarcasm_model = AutoModelForSequenceClassification.from_pretrained(sarc)
classification_tokenizer = AutoTokenizer.from_pretrained(doc, use_fast=False)
classification_model = AutoModelForSequenceClassification.from_pretrained(doc)
# ** Mistral model for RAG **
# Causal language model that writes the chatbot's answer from retrieved context.
mistral_model_name = "mistralai/Mistral-7B-v0.1"
causal_tokenizer = AutoTokenizer.from_pretrained(mistral_model_name)
# NOTE(review): the weights are loaded in float16 and never moved off the
# default device here; half-precision inference on CPU can be very slow or
# unsupported depending on the torch build — confirm the deployment target.
causal_model = AutoModelForCausalLM.from_pretrained(
    mistral_model_name,
    torch_dtype=torch.float16,
)
causal_model.eval()
# Paths and files
UPLOAD_FOLDER = "uploads"  # a copy of each uploaded comments file is saved here
SUMMARY_FILE = "summary.txt"
FAISS_INDEX_PATH = "faiss_index"  # serialized FAISS index of the text chunks
DOCUMENTS_FILE = "documents.txt"  # chunk texts, in the same order as the index rows
# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Output class index -> human-readable label of the review-classification model.
categories = dict(enumerate([
    "Shipping and Delivery",
    "Customer Service",
    "Price and Value",
    "Quality and Performance",
    "Use and Design",
    "Other",
]))
# Helper functions
def analyze_sentiment(sentence):
    """Label *sentence* as "Positive" or "Negative" using the sentiment model.

    The highest-scoring class index 0 maps to "Positive"; every other index
    maps to "Negative". Input longer than 512 tokens is truncated.
    """
    encoded = sentiment_tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        predicted = sentiment_model(**encoded).logits.argmax(dim=-1).item()
    if predicted == 0:
        return "Positive"
    return "Negative"
def detect_sarcasm(sentence):
    """Return True when the sarcasm model's top class for *sentence* is 1.

    Input longer than 512 tokens is truncated.
    """
    encoded = sarcasm_tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        scores = sarcasm_model(**encoded).logits
    return torch.argmax(scores, dim=-1).item() == 1
def classify_document(sentence):
    """Return the review category name predicted for *sentence*.

    The top class index is looked up in the module-level `categories` table,
    so an index missing from that table raises KeyError. Input longer than
    512 tokens is truncated.
    """
    encoded = classification_tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        scores = classification_model(**encoded).logits
    label_index = int(scores.argmax(dim=-1).item())
    return categories[label_index]
def preprocess_summary(file_path):
    """Split a UTF-8 text file into chunks that start at header lines.

    Every line is stripped and blank lines are dropped. A line ending in ":"
    is a header: it closes the chunk in progress (if there is one) and opens
    a new chunk. Each chunk is returned as its lines joined with newlines,
    in file order; an empty or blank file yields an empty list.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        stripped = [raw.strip() for raw in handle.readlines()]

    sections = []
    pending = []
    for text in filter(None, stripped):  # filter(None, ...) drops blank lines
        if text.endswith(":") and pending:
            sections.append("\n".join(pending))
            pending = []
        pending.append(text)
    if pending:
        sections.append("\n".join(pending))
    return sections
def create_faiss_index(chunks):
    """Embed *chunks*, index them with FAISS and persist index and texts to disk.

    The index is written to FAISS_INDEX_PATH and the chunk texts to
    DOCUMENTS_FILE, each followed by a "--END--" separator line and in the
    same order as the index rows, so a search hit's row id is also the
    chunk's position in DOCUMENTS_FILE.

    Raises:
        ValueError: if *chunks* is empty. There is nothing to index, and the
            embedding dimension could not be inferred.
    """
    if not chunks:
        raise ValueError("No text chunks to index; the uploaded file appears to be empty.")
    # A single batched encode call instead of one call per chunk; FAISS needs float32.
    embeddings_np = np.asarray(
        embedding_model.encode(chunks, normalize_embeddings=True), dtype="float32"
    )
    embedding_dimension = embeddings_np.shape[1]
    # Embeddings are unit-normalised, so L2 distance ranks results like cosine similarity.
    faiss_index = faiss.IndexFlatL2(embedding_dimension)
    faiss_index.add(embeddings_np)
    faiss.write_index(faiss_index, FAISS_INDEX_PATH)
    with open(DOCUMENTS_FILE, "w", encoding="utf-8") as doc_file:
        doc_file.writelines(chunk + "\n--END--\n" for chunk in chunks)
def _read_uploaded_text(file):
    """Return the text of a Gradio upload value, or None if nothing was uploaded.

    Depending on the Gradio version and the component's `type`, `gr.File`
    may pass a str holding the uploaded file's temporary *path* (Gradio's
    `NamedString`), an object exposing that path as `.name` (older tempfile
    wrapper), or raw `bytes`. A string that is not an existing file path is
    treated as the text itself.
    """
    if file is None:
        return None
    if isinstance(file, bytes):
        return file.decode("utf-8", errors="replace")
    path = file if isinstance(file, str) else getattr(file, "name", None)
    if isinstance(path, str) and os.path.isfile(path):
        with open(path, "r", encoding="utf-8", errors="replace") as fh:
            return fh.read()
    return str(file)


def _group_by_category(results):
    """Group analysed comment dicts by their "category", in first-seen order."""
    grouped = {}
    for item in results:
        grouped.setdefault(item["category"], []).append(item)
    return grouped


def _count_sentiments(items):
    """Return (positive, negative) counts for a list of analysed comment dicts."""
    positive = sum(1 for item in items if item["sentiment"] == "Positive")
    return positive, len(items) - positive


def _build_summary(grouped):
    """Render grouped results as text split into sections by header lines.

    Each category becomes a "<Category>:" header line, followed by its
    sentiment counts and one line per comment. preprocess_summary treats
    lines ending in ":" as headers. Comment lines always end in ")", so a
    comment that itself ends in ":" cannot start a spurious section.
    """
    lines = []
    for category, items in grouped.items():
        positive, negative = _count_sentiments(items)
        lines.append(f"{category}:")
        lines.append(f"Positive comments: {positive}, Negative comments: {negative}")
        lines.extend(
            f"- {item['comment']} (sentiment: {item['sentiment']})" for item in items
        )
    return "\n".join(lines) + "\n"


def handle_uploaded_file(file):
    """Analyse an uploaded comments file and rebuild the RAG index from the results.

    Each non-blank line is one comment. Every comment gets a sentiment and a
    category. A "Positive" prediction that is also flagged as sarcastic is
    recorded as "Negative". The labelled comments are written to
    SUMMARY_FILE, grouped by category, and that summary is chunked and
    indexed. The chatbot therefore answers from the analysis rather than
    from the raw text.

    Args:
        file: The value Gradio passes for the `gr.File` input (a path-like
            str, a tempfile wrapper, bytes, or None when nothing was uploaded).

    Returns:
        A short processing report for the UI.
    """
    text = _read_uploaded_text(file)
    if text is None:
        return "Please upload a .txt file first."
    # Keep a copy of the raw upload, then read it back line by line.
    file_path = os.path.join(UPLOAD_FOLDER, "uploaded_comments.txt")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(text)
    with open(file_path, "r", encoding="utf-8") as f:
        comments = f.readlines()
    results = []
    for comment in comments:
        comment = comment.strip()
        if not comment:
            continue
        sentiment = analyze_sentiment(comment)
        # A sarcastic "positive" review is counted as negative.
        if sentiment == "Positive" and detect_sarcasm(comment):
            sentiment = "Negative"
        category = classify_document(comment)
        results.append({"comment": comment, "sentiment": sentiment, "category": category})
    if not results:
        # Nothing to index; building an index from zero chunks would fail.
        return "The uploaded file contains no comments to process."
    grouped = _group_by_category(results)
    with open(SUMMARY_FILE, "w", encoding="utf-8") as f:
        f.write(_build_summary(grouped))
    create_faiss_index(preprocess_summary(SUMMARY_FILE))
    report = [
        "File uploaded and processed successfully.",
        f"Analysed {len(results)} comments.",
    ]
    for category, items in grouped.items():
        positive, negative = _count_sentiments(items)
        report.append(f"{category}: {positive} positive, {negative} negative")
    return "\n".join(report)
def causal_generate_response(prompt, max_new_tokens=300):
    """Generate a sampled continuation of *prompt* with the Mistral model.

    Args:
        prompt: Full prompt text (instructions, retrieved context and question).
        max_new_tokens: Upper bound on generated tokens, not counting the
            prompt. A total-length cap such as `max_length` also counts prompt
            tokens, so a long retrieved context could leave no room for an answer.

    Returns:
        Only the newly generated text, without the echoed prompt.
    """
    # Place the encoded prompt on whatever device the model lives on.
    inputs = causal_tokenizer(prompt, return_tensors="pt").to(causal_model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = causal_model.generate(
            **inputs,  # passes attention_mask along with input_ids
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            # NOTE(review): base Mistral tokenizers typically ship without a pad
            # token; reuse EOS so generate() does not have to guess — confirm.
            pad_token_id=causal_tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation for decoder-only models; keep
    # just the continuation.
    new_tokens = outputs[0][prompt_length:]
    return causal_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def query_chatbot(query, top_k=5):
    """Answer *query* with the causal LM, grounded in the nearest indexed chunks.

    Loads the FAISS index and chunk texts written by create_faiss_index. It
    then retrieves up to *top_k* chunks nearest to the query embedding and
    passes them as context in an analyst-style prompt.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to retrieve as context (must be >= 1).

    Returns:
        The generated answer, or a short guidance message when there is no
        question or no processed data yet.
    """
    if not query or not query.strip():
        return "Please enter a question."
    if not (os.path.exists(FAISS_INDEX_PATH) and os.path.exists(DOCUMENTS_FILE)):
        return "No processed data found. Please upload and process a file first."
    faiss_index = faiss.read_index(FAISS_INDEX_PATH)
    with open(DOCUMENTS_FILE, "r", encoding="utf-8") as doc_file:
        documents = doc_file.read().split("\n--END--\n")
    query_embedding = embedding_model.encode([query], normalize_embeddings=True)
    _, indices = faiss_index.search(np.asarray(query_embedding, dtype="float32"), top_k)
    # FAISS pads missing results with -1. Without the lower bound that would
    # index from the end of `documents` (the empty piece after the final
    # separator), so negative ids and empty texts are skipped.
    relevant_docs = [
        documents[idx]
        for idx in indices[0]
        if 0 <= idx < len(documents) and documents[idx]
    ]
    context = "\n\n".join(relevant_docs)
    # Custom Prompt for RAG
    final_prompt = (
        "You are a business data analyst. Analyze the feedback data and identify the overall sentiment trends. "
        "Focus on determining whether positive feedback or negative feedback dominates in each category, and avoid overstating less significant trends. "
        "Provide clear, data-driven insights.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Your Answer (based on the data and context):"
    )
    return causal_generate_response(final_prompt)
# Gradio interface

def process_file_and_show_chatbot(file):
    """Run the upload pipeline and return its status text for the report box."""
    return handle_uploaded_file(file)


def handle_query(query):
    """Send the user's question to the RAG chatbot and return its answer."""
    return query_chatbot(query)


with gr.Blocks() as interface:
    gr.Markdown("# Sentiment Analysis Powered by Sarcasm Detection")
    # Upload section: file picker next to the processing report.
    with gr.Row():
        upload = gr.File(label="Upload .txt File")
        chatbot_output = gr.Textbox(label="Processing Report", lines=10, interactive=False)
    upload_btn = gr.Button("Process File")
    # Question section: question box next to the generated answer.
    with gr.Row():
        query_input = gr.Textbox(label="Ask a Question")
        answer_output = gr.Textbox(label="Answer", lines=5, interactive=False)
    query_btn = gr.Button("Get Answer")
    upload_btn.click(process_file_and_show_chatbot, inputs=upload, outputs=chatbot_output)
    query_btn.click(handle_query, inputs=query_input, outputs=answer_output)

# Run Gradio app
if __name__ == "__main__":
    interface.launch()