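"""Gradio Space: customer-review analytics with a RAG chatbot.

Pipeline: score each uploaded review for sentiment (with a sarcasm
override) and topic, chunk the file, index the chunks in FAISS, and
answer questions over them with Mistral-7B.
"""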
import gradio as gr
import torch
import faiss
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

# Model paths
sent = "dnzblgn/Sentiment-Analysis-Customer-Reviews"
sarc = "dnzblgn/Sarcasm-Detection-Customer-Reviews"
doc = "dnzblgn/Customer-Reviews-Classification"

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
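# all-MiniLM-L6-v2 encodes text into 384-dimensional sentence embeddings,
# which sets the dimension of the FAISS index built below.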
# Fine-tuned classifiers (public models, slow tokenizers, no auth token)
sentiment_tokenizer = AutoTokenizer.from_pretrained(sent, use_fast=False)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(sent)

sarcasm_tokenizer = AutoTokenizer.from_pretrained(sarc, use_fast=False)
sarcasm_model = AutoModelForSequenceClassification.from_pretrained(sarc)

classification_tokenizer = AutoTokenizer.from_pretrained(doc, use_fast=False)
classification_model = AutoModelForSequenceClassification.from_pretrained(doc)

# Mistral model for RAG. Note: the fp16 weights alone are roughly 14 GB,
# so this load will fail on a small CPU Space -- a likely cause of the
# "Runtime error" shown on the Space page.
mistral_model_name = "mistralai/Mistral-7B-v0.1"
causal_tokenizer = AutoTokenizer.from_pretrained(mistral_model_name)
causal_model = AutoModelForCausalLM.from_pretrained(mistral_model_name, torch_dtype=torch.float16).eval()
# Paths and files
UPLOAD_FOLDER = "uploads"
SUMMARY_FILE = "summary.txt"
FAISS_INDEX_PATH = "faiss_index"
DOCUMENTS_FILE = "documents.txt"

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
categories = {
    0: "Shipping and Delivery",
    1: "Customer Service",
    2: "Price and Value",
    3: "Quality and Performance",
    4: "Use and Design",
    5: "Other",
}
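# The index -> label mapping above (and the 0 = Positive / 1 = Sarcastic
# conventions used below) are assumed to match the label order these models
# were fine-tuned with; adjust them if the model cards say otherwise.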
# Helper functions
def analyze_sentiment(sentence):
    inputs = sentiment_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = sentiment_model(**inputs)
    logits = outputs.logits
    sentiment = torch.argmax(logits, dim=-1).item()
    return "Positive" if sentiment == 0 else "Negative"
def detect_sarcasm(sentence):
    inputs = sarcasm_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = sarcasm_model(**inputs)
    logits = outputs.logits
    sarcasm = torch.argmax(logits, dim=-1).item()
    return sarcasm == 1  # True when the review reads as sarcastic
def classify_document(sentence):
    inputs = classification_tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = classification_model(**inputs)
    logits = outputs.logits
    category = torch.argmax(logits, dim=-1).item()
    return categories[category]
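# The three classifiers combine per comment in handle_uploaded_file below:
# a "Positive" verdict is flipped to "Negative" when the sarcasm detector
# fires. Illustrative case: "Great, it broke on day one." -> Positive +
# sarcastic -> Negative.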
def preprocess_summary(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        lines = file.readlines()

    chunks = []
    current_chunk = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.endswith(":") and current_chunk:
            chunks.append("\n".join(current_chunk))
            current_chunk = []
        current_chunk.append(line)
    if current_chunk:
        chunks.append("\n".join(current_chunk))
    return chunks
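# Chunking example: a line ending in ":" starts a new chunk, so the lines
#   Shipping: / fast / cheap / Service: / slow
# become two chunks, "Shipping:\nfast\ncheap" and "Service:\nslow".
# (Assumes the uploaded file uses such header lines; otherwise the whole
# file becomes a single chunk.)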
def create_faiss_index(chunks):
    # encode() on a list batches the chunks and returns a float32 numpy array
    embeddings_np = embedding_model.encode(chunks, normalize_embeddings=True)
    embedding_dimension = embeddings_np.shape[1]
    faiss_index = faiss.IndexFlatL2(embedding_dimension)
    faiss_index.add(embeddings_np)
    faiss.write_index(faiss_index, FAISS_INDEX_PATH)
    with open(DOCUMENTS_FILE, "w", encoding="utf-8") as doc_file:
        for chunk in chunks:
            doc_file.write(chunk + "\n--END--\n")
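# With unit-normalized embeddings, squared L2 distance is 2 - 2*cos(a, b),
# so IndexFlatL2 ranks neighbors exactly as cosine similarity would.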
def handle_uploaded_file(file):
    # Depending on the Gradio version, `file` may arrive as a temp-file path
    # (the current gr.File default) or as the raw text content; handle both.
    if isinstance(file, str) and os.path.exists(file):
        with open(file, "r", encoding="utf-8") as src:
            content = src.read()
    else:
        content = str(file)
    file_path = os.path.join(UPLOAD_FOLDER, "uploaded_comments.txt")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)

    results = []
    for comment in content.splitlines():
        comment = comment.strip()
        if not comment:
            continue
        sentiment = analyze_sentiment(comment)
        # Sarcasm overrides an apparently positive verdict
        if sentiment == "Positive" and detect_sarcasm(comment):
            sentiment = "Negative"
        category = classify_document(comment)
        results.append({"comment": comment, "sentiment": sentiment, "category": category})

    chunks = preprocess_summary(file_path)
    create_faiss_index(chunks)
    return f"File uploaded and processed successfully ({len(results)} comments analyzed)."
def causal_generate_response(prompt):
    inputs = causal_tokenizer(prompt, return_tensors="pt")  # default CPU
    with torch.no_grad():
        # max_new_tokens bounds only the generated text; the previous
        # max_length=500 counted the (long) RAG prompt as well.
        outputs = causal_model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens, not the echoed prompt
    response = causal_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return response
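# do_sample=True with temperature=0.7 gives varied answers on each call;
# set do_sample=False for deterministic greedy decoding if reproducibility
# matters more than variety.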
def query_chatbot(query):
    top_k = 5
    faiss_index = faiss.read_index(FAISS_INDEX_PATH)
    with open(DOCUMENTS_FILE, "r", encoding="utf-8") as doc_file:
        documents = doc_file.read().split("\n--END--\n")

    query_embedding = embedding_model.encode([query], normalize_embeddings=True)
    distances, indices = faiss_index.search(np.array(query_embedding), top_k)
    # FAISS pads with -1 when fewer than top_k results exist, so check both bounds
    relevant_docs = [documents[idx] for idx in indices[0] if 0 <= idx < len(documents)]
    context = "\n\n".join(relevant_docs)

    # Custom prompt for RAG
    final_prompt = (
        f"You are a business data analyst. Analyze the feedback data and identify the overall sentiment trends. "
        f"Focus on determining whether positive or negative feedback dominates in each category, and avoid overstating less significant trends. "
        f"Provide clear, data-driven insights.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Your Answer (based on the data and context):"
    )
    return causal_generate_response(final_prompt)
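# Illustrative call (query text is hypothetical):
#   query_chatbot("Which category receives the most negative feedback?")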
# Gradio interface
with gr.Blocks() as interface:
    gr.Markdown("# Sentiment Analysis Powered by Sarcasm Detection")
    with gr.Row():
        upload = gr.File(label="Upload .txt File")
        chatbot_output = gr.Textbox(label="Processing Report", lines=10, interactive=False)
    upload_btn = gr.Button("Process File")

    with gr.Row():
        query_input = gr.Textbox(label="Ask a Question")
        answer_output = gr.Textbox(label="Answer", lines=5, interactive=False)
    query_btn = gr.Button("Get Answer")

    upload_btn.click(handle_uploaded_file, inputs=upload, outputs=chatbot_output)
    query_btn.click(query_chatbot, inputs=query_input, outputs=answer_output)
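# Optional: calling interface.queue() before launch() serializes requests,
# which helps when Mistral generation runs long enough to hit HTTP timeouts.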
# Run Gradio app
if __name__ == "__main__":
    interface.launch()