import streamlit as st
import numpy as np
import transformers
import re
import string
import preprocessor as pre

import torch
from transformers import BertTokenizer, BertForSequenceClassification

with open("style.css") as f:
    st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)

# Fine-tuned BERT sentiment model hosted on the Hugging Face Hub.
model_path = "ninahf1503/SA-BERTchatgptapp"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path, ignore_mismatched_sizes=True)

# Maximum token length used when encoding input sentences.
seq_max_length = 55
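
# Note: Streamlit re-runs this script on every user interaction, so the tokenizer
# and model above are reloaded each time. One way to avoid that is a cached loader
# (a sketch, assuming Streamlit >= 1.18, where st.cache_resource is available):
#
#   @st.cache_resource
#   def load_model(path):
#       return (BertTokenizer.from_pretrained(path),
#               BertForSequenceClassification.from_pretrained(path, ignore_mismatched_sizes=True))
#
#   tokenizer, model = load_model(model_path)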


def tokenizing_text(sentence):
    """Clean a sentence and encode it into BERT input tensors."""
    sentence = preprocess_text(sentence)
    encoded = tokenizer.encode_plus(
        sentence,
        add_special_tokens=True,
        max_length=seq_max_length,
        truncation=True,
        padding='max_length',
        return_tensors='pt'
    )

    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    return input_ids, attention_mask
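
# Shape check (illustrative, using the tokenizer settings above):
#   input_ids, attention_mask = tokenizing_text("I love this app")
#   input_ids.shape       -> torch.Size([1, 55])
#   attention_mask.shape  -> torch.Size([1, 55])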


def preprocess_text(sentence):
    """Normalise a tweet-style sentence before tokenization."""
    # Mentions, URLs, hashtags, leading "RT" markers and stray digits.
    re_cleansing = r"@\S+|https?:\S+|http?:\S|#[A-Za-z0-9]+|^RT[\s]+|(^|\W)\d+"

    sentence = sentence.encode().decode('unicode_escape')
    sentence = re.sub(r'\n', ' ', sentence)
    # tweet-preprocessor strips URLs, mentions, hashtags, emojis, etc.
    sentence = pre.clean(sentence)
    sentence = re.sub(r'[^\w\s]', ' ', sentence)
    sentence = re.sub(r'[0-9]', ' ', sentence)
    sentence = re.sub(re_cleansing, ' ', sentence).strip()

    # Remove any punctuation that survived the regex passes.
    for punctuation in string.punctuation:
        sentence = sentence.replace(punctuation, '')

    return sentence.lower()
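
# Illustrative example (exact output depends on tweet-preprocessor's defaults):
#   preprocess_text("RT @user: Check https://t.co/xyz #ChatGPT 2023!")
#   -> roughly "check"  (mentions, URLs, hashtags, digits and punctuation removed,
#      text lower-cased)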


def predict_sentiment(input_text):
    """Run the fine-tuned BERT classifier and map the result to a label."""
    input_ids, attention_mask = tokenizing_text(input_text)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    logits = outputs.logits
    predict_class = torch.argmax(logits, dim=1).item()

    label_sentiment = {0: "Bad", 1: "Good", 2: "Neutral"}
    predict_label = label_sentiment[predict_class]

    return predict_label
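
# Usage:
#   predict_sentiment("ChatGPT makes my work easier")  # returns "Good", "Bad" or "Neutral"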


def main():
    st.title("Sentiment Analysis", anchor=False)
    tweet_text = st.text_area(" ", placeholder="Enter the sentence you want to analyze", label_visibility="collapsed")

    if st.button("SUBMIT"):
        if tweet_text.strip() == "":
            st.title("Text Input Still Empty", anchor=False)
            st.info("Please fill in the sentence you want to analyze")
        else:
            sentiment = predict_sentiment(tweet_text)
            st.title("Sentiment Analysis Results", anchor=False)
            if sentiment == "Good":
                st.markdown('<div style="background-color: #5d9c59; padding: 16px; border-radius: 5px; font-weight: bold; color:white;">This sentence contains a positive sentiment</div>', unsafe_allow_html=True)
            elif sentiment == "Bad":
                st.markdown('<div style="background-color: #df2e38; padding: 16px; border-radius: 5px; font-weight: bold; color:white;">This sentence contains a negative sentiment</div>', unsafe_allow_html=True)
            else:
                st.markdown('<div style="background-color: #ffa500; padding: 16px; border-radius: 5px; font-weight: bold; color:white;">This sentence is neutral</div>', unsafe_allow_html=True)


if __name__ == "__main__":
    main()
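
# To try this out locally (assuming the file is saved as app.py and style.css is
# in the same directory):
#   streamlit run app.py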