adarsh committed on
Commit
3dff25b
1 Parent(s): 0f27aa7
Files changed (2) hide show
  1. app.py +55 -0
  2. requirements.txt +110 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ import torch
4
+ from transformers import BertTokenizer, BertForSequenceClassification
5
+
6
# Category name -> class index. The number of entries sizes the
# classification head of the model below.
label_dict = {
    "Urgency": 0,
    "Not Dark Pattern": 1,
    "Scarcity": 2,
    "Misdirection": 3,
    "Social Proof": 4,
    "Obstruction": 5,
    "Sneaking": 6,
    "Forced Action": 7,
}

# Lower-casing tokenizer matching the uncased BERT checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# Build the base architecture with one output per label, then overwrite its
# parameters with the state dict stored on disk, mapped onto the CPU.
fine_tuned_model_path = "models/finetuned_BERT_epoch_5.model"
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(label_dict),
)
model.load_state_dict(
    torch.load(fine_tuned_model_path, map_location=torch.device("cpu"))
)
11
+
12
def get_dark_pattern_name(label):
    """Return the category name whose class index in ``label_dict`` is *label*.

    Raises:
        KeyError: if no category in ``label_dict`` maps to *label*.
    """
    # Invert the name -> index mapping; values() and keys() iterate in the
    # same order, so zipping them pairs each index with its own name.
    names_by_index = dict(zip(label_dict.values(), label_dict.keys()))
    return names_by_index[label]
15
+
16
def find_dark_pattern(text_predict):
    """Classify *text_predict* and return the predicted dark-pattern category name.

    The text is tokenized with the module-level ``tokenizer``, padded and
    truncated to exactly 256 tokens, and passed through the module-level
    ``model`` with gradients disabled. The index of the largest logit is
    translated back to a category name via ``get_dark_pattern_name``.

    Args:
        text_predict: The text to classify.

    Returns:
        The name of the highest-scoring category from ``label_dict``.
    """
    encoded_text = tokenizer.encode_plus(
        text_predict,
        add_special_tokens=True,
        return_attention_mask=True,
        # `pad_to_max_length` is deprecated; request fixed-length padding
        # through the supported `padding` argument instead.
        padding="max_length",
        # Ask for truncation explicitly so over-long inputs are cut to
        # `max_length` by design rather than by a legacy fallback.
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )

    # Inference mode: disables dropout so repeated calls are deterministic.
    model.eval()

    with torch.no_grad():
        outputs = model(
            input_ids=encoded_text["input_ids"],
            attention_mask=encoded_text["attention_mask"],
        )

    # Softmax is monotonic, so the arg-max of the raw logits is already the
    # most probable class; no need to normalise first.
    predicted_label = torch.argmax(outputs.logits, dim=1).item()

    return get_dark_pattern_name(predicted_label)
41
+
42
def predict(text_to_predict):
    """Classify *text_to_predict* and report the result with the elapsed time.

    Inference runs exactly once. The previous version repeated the same
    inference ten times and slept half a second after each pass, which added
    at least five seconds of latency and inflated the reported time without
    changing the result.

    Args:
        text_to_predict: The text entered by the user.

    Returns:
        A two-line string: the predicted category, then the wall-clock time
        the prediction took, formatted to two decimal places.
    """
    # perf_counter is a monotonic clock intended for measuring intervals.
    start_time = time.perf_counter()
    print("Predicting Dark Pattern...")
    predicted_darkp = find_dark_pattern(text_to_predict)
    total_time = time.perf_counter() - start_time
    return f"Result: {predicted_darkp}\nTotal Time Taken: {total_time:.2f} seconds"
51
+
52
# Text-in / text-out web UI around `predict`.
demo = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
)

# Start the server; share=True additionally requests a public share link.
demo.launch(share=True)
54
+
55
+
requirements.txt ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.3
3
+ aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.6.0
6
+ anyio==4.3.0
7
+ asgiref==3.8.1
8
+ attrs==23.2.0
9
+ blinker==1.8.1
10
+ cachetools==5.3.3
11
+ certifi==2024.2.2
12
+ charset-normalizer==3.3.2
13
+ click==8.1.7
14
+ colorama==0.4.6
15
+ contourpy==1.2.1
16
+ cycler==0.12.1
17
+ datasets==2.18.0
18
+ dill==0.3.8
19
+ Django==5.0.3
20
+ dnspython==2.6.1
21
+ email_validator==2.1.1
22
+ fastapi==0.111.0
23
+ fastapi-cli==0.0.2
24
+ ffmpy==0.3.2
25
+ filelock==3.13.4
26
+ Flask==3.0.3
27
+ fonttools==4.51.0
28
+ frozenlist==1.4.1
29
+ fsspec==2024.2.0
30
+ gitdb==4.0.11
31
+ GitPython==3.1.43
32
+ gradio==4.29.0
33
+ gradio_client==0.16.1
34
+ h11==0.14.0
35
+ httpcore==1.0.5
36
+ httptools==0.6.1
37
+ httpx==0.27.0
38
+ huggingface-hub==0.22.2
39
+ idna==3.6
40
+ importlib_resources==6.4.0
41
+ itsdangerous==2.2.0
42
+ Jinja2==3.1.3
43
+ jsonschema==4.21.1
44
+ jsonschema-specifications==2023.12.1
45
+ kiwisolver==1.4.5
46
+ markdown-it-py==3.0.0
47
+ MarkupSafe==2.1.5
48
+ matplotlib==3.8.4
49
+ mdurl==0.1.2
50
+ mpmath==1.3.0
51
+ multidict==6.0.5
52
+ multiprocess==0.70.16
53
+ mysqlclient==2.2.4
54
+ networkx==3.3
55
+ numpy==1.26.4
56
+ orjson==3.10.3
57
+ packaging==24.0
58
+ pandas==2.2.1
59
+ pillow==10.3.0
60
+ protobuf==4.25.3
61
+ pyarrow==15.0.2
62
+ pyarrow-hotfix==0.6
63
+ pydantic==2.7.1
64
+ pydantic_core==2.18.2
65
+ pydeck==0.9.0b1
66
+ pydub==0.25.1
67
+ Pygments==2.17.2
68
+ pyparsing==3.1.2
69
+ python-dateutil==2.9.0.post0
70
+ python-dotenv==1.0.1
71
+ python-multipart==0.0.9
72
+ pytz==2024.1
73
+ PyYAML==6.0.1
74
+ referencing==0.35.0
75
+ regex==2024.4.16
76
+ requests==2.31.0
77
+ rich==13.7.1
78
+ rpds-py==0.18.0
79
+ ruff==0.4.3
80
+ safetensors==0.4.3
81
+ semantic-version==2.10.0
82
+ shellingham==1.5.4
83
+ six==1.16.0
84
+ smmap==5.0.1
85
+ sniffio==1.3.1
86
+ sqlparse==0.4.4
87
+ starlette==0.37.2
88
+ streamlit==1.33.0
89
+ sympy==1.12
90
+ tenacity==8.2.3
91
+ tokenizers==0.19.1
92
+ toml==0.10.2
93
+ tomlkit==0.12.0
94
+ toolz==0.12.1
95
+ torch==2.2.2
96
+ tornado==6.4
97
+ tqdm==4.66.2
98
+ transformers==4.40.1
99
+ typer==0.12.3
100
+ typing_extensions==4.11.0
101
+ tzdata==2024.1
102
+ ujson==5.9.0
103
+ urllib3==2.2.1
104
+ uvicorn==0.29.0
105
+ watchdog==4.0.0
106
+ watchfiles==0.21.0
107
+ websockets==11.0.3
108
+ Werkzeug==3.0.2
109
+ xxhash==3.4.1
110
+ yarl==1.9.4