Spaces:
Runtime error
Runtime error
khaiphan29
commited on
Commit
•
0217fc8
1
Parent(s):
0af052d
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- Dockerfile +14 -0
- README.md +4 -6
- main.py +77 -0
- requirements.txt +63 -0
- script.py +5 -0
- src/.DS_Store +0 -0
- src/__init__.py +0 -0
- src/crawler.py +256 -0
- src/mDeBERTa (ft) V6/.DS_Store +0 -0
- src/mDeBERTa (ft) V6/cls.pt +3 -0
- src/mDeBERTa (ft) V6/cls_log.txt +76 -0
- src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-mean/config.json +45 -0
- src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-mean/model.safetensors +3 -0
- src/mDeBERTa (ft) V6/mean.pt +3 -0
- src/mDeBERTa (ft) V6/mean_log.txt +76 -0
- src/mDeBERTa (ft) V6/plot.png +0 -0
- src/mDeBERTa (ft) V6/public_train_v4.json +3 -0
- src/myNLI.py +190 -0
- src/nli_v3.py +115 -0
- src/utils.py +12 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
2 |
+
# you will also find guides on how best to write your Dockerfile
|
3 |
+
|
4 |
+
FROM python:3.9
|
5 |
+
|
6 |
+
WORKDIR /code
|
7 |
+
|
8 |
+
COPY ./requirements.txt /code/requirements.txt
|
9 |
+
|
10 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
11 |
+
|
12 |
+
COPY . .
|
13 |
+
|
14 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
@@ -1,11 +1,9 @@
|
|
1 |
---
|
2 |
-
title: Fact
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
colorTo: blue
|
6 |
-
sdk:
|
7 |
-
sdk_version: 4.12.0
|
8 |
-
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
|
|
1 |
---
|
2 |
+
title: Fact Checking Api
|
3 |
+
emoji: 📊
|
4 |
+
colorFrom: pink
|
5 |
colorTo: blue
|
6 |
+
sdk: docker
|
|
|
|
|
7 |
pinned: false
|
8 |
---
|
9 |
|
main.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#uvicorn main:app --reload
|
2 |
+
from fastapi import FastAPI, status
|
3 |
+
from fastapi.responses import Response, JSONResponse
|
4 |
+
from pydantic import BaseModel
|
5 |
+
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
import os
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
|
12 |
+
from src.myNLI import FactChecker
|
13 |
+
from src.crawler import MyCrawler
|
14 |
+
|
15 |
+
#request body
|
16 |
+
class Claim(BaseModel):
|
17 |
+
claim: str
|
18 |
+
|
19 |
+
class ScrapeBase(BaseModel):
|
20 |
+
id: int
|
21 |
+
name: str
|
22 |
+
scraping_url: str
|
23 |
+
|
24 |
+
class ScrapeList(BaseModel):
|
25 |
+
data: List[ScrapeBase]
|
26 |
+
|
27 |
+
app = FastAPI()
|
28 |
+
|
29 |
+
# load model
|
30 |
+
t_0 = time.time()
|
31 |
+
fact_checker = FactChecker()
|
32 |
+
t_load = time.time() - t_0
|
33 |
+
print("time load model: {}".format(t_load))
|
34 |
+
|
35 |
+
crawler = MyCrawler()
|
36 |
+
|
37 |
+
label_code = {
|
38 |
+
"REFUTED": 0,
|
39 |
+
"SUPPORTED": 1,
|
40 |
+
"NEI": 2
|
41 |
+
}
|
42 |
+
|
43 |
+
@app.get("/")
|
44 |
+
async def root():
|
45 |
+
return {"msg": "This is for interacting with Fact-checking AI Model"}
|
46 |
+
|
47 |
+
@app.post("/ai-fact-check")
|
48 |
+
async def get_claim(req: Claim):
|
49 |
+
claim = req.claim
|
50 |
+
result = fact_checker.predict(claim)
|
51 |
+
print(result)
|
52 |
+
|
53 |
+
if not result:
|
54 |
+
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
55 |
+
|
56 |
+
return { "claim": claim,
|
57 |
+
"final_label": label_code[result["label"]],
|
58 |
+
"evidence": result["evidence"],
|
59 |
+
"provider": result["provider"],
|
60 |
+
"url": result["url"]
|
61 |
+
}
|
62 |
+
|
63 |
+
@app.post("/scraping-check")
|
64 |
+
async def get_claim(req: ScrapeList):
|
65 |
+
response = []
|
66 |
+
for ele in req.data:
|
67 |
+
response.append({
|
68 |
+
"id": ele.id,
|
69 |
+
"name": ele.name,
|
70 |
+
"scraping_url": ele.scraping_url,
|
71 |
+
"status": crawler.scraping(ele.scraping_url)
|
72 |
+
})
|
73 |
+
|
74 |
+
|
75 |
+
return JSONResponse({
|
76 |
+
"list": response
|
77 |
+
})
|
requirements.txt
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohttp==3.9.1
|
2 |
+
aiosignal==1.3.1
|
3 |
+
annotated-types==0.6.0
|
4 |
+
anyio==4.2.0
|
5 |
+
async-timeout==4.0.3
|
6 |
+
attrs==23.2.0
|
7 |
+
beautifulsoup4==4.12.2
|
8 |
+
certifi==2023.11.17
|
9 |
+
charset-normalizer==3.3.2
|
10 |
+
click==8.1.7
|
11 |
+
datasets==2.16.1
|
12 |
+
dill==0.3.7
|
13 |
+
exceptiongroup==1.2.0
|
14 |
+
fastapi==0.108.0
|
15 |
+
filelock==3.13.1
|
16 |
+
frozenlist==1.4.1
|
17 |
+
fsspec==2023.10.0
|
18 |
+
h11==0.14.0
|
19 |
+
huggingface-hub==0.20.1
|
20 |
+
idna==3.6
|
21 |
+
Jinja2==3.1.2
|
22 |
+
joblib==1.3.2
|
23 |
+
MarkupSafe==2.1.3
|
24 |
+
mpmath==1.3.0
|
25 |
+
multidict==6.0.4
|
26 |
+
multiprocess==0.70.15
|
27 |
+
networkx==3.2.1
|
28 |
+
nltk==3.8.1
|
29 |
+
numpy==1.26.2
|
30 |
+
packaging==23.2
|
31 |
+
pandas==2.1.4
|
32 |
+
Pillow==10.1.0
|
33 |
+
pyarrow==14.0.2
|
34 |
+
pyarrow-hotfix==0.6
|
35 |
+
pydantic==2.5.3
|
36 |
+
pydantic_core==2.14.6
|
37 |
+
python-dateutil==2.8.2
|
38 |
+
pytz==2023.3.post1
|
39 |
+
PyYAML==6.0.1
|
40 |
+
regex==2023.12.25
|
41 |
+
requests==2.31.0
|
42 |
+
safetensors==0.4.1
|
43 |
+
scikit-learn==1.3.2
|
44 |
+
scipy==1.11.4
|
45 |
+
sentence-transformers==2.2.2
|
46 |
+
sentencepiece==0.1.99
|
47 |
+
six==1.16.0
|
48 |
+
sniffio==1.3.0
|
49 |
+
soupsieve==2.5
|
50 |
+
starlette==0.32.0.post1
|
51 |
+
sympy==1.12
|
52 |
+
threadpoolctl==3.2.0
|
53 |
+
tokenizers==0.15.0
|
54 |
+
torch==2.1.2
|
55 |
+
torchvision==0.16.2
|
56 |
+
tqdm==4.66.1
|
57 |
+
transformers==4.36.2
|
58 |
+
typing_extensions==4.9.0
|
59 |
+
tzdata==2023.4
|
60 |
+
urllib3==2.1.0
|
61 |
+
uvicorn==0.25.0
|
62 |
+
xxhash==3.4.1
|
63 |
+
yarl==1.9.4
|
script.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
api.upload_folder(
|
2 |
+
folder_path="./src",
|
3 |
+
repo_id="khaiphan29/fact-check-api",
|
4 |
+
repo_type="space",
|
5 |
+
)
|
src/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/__init__.py
ADDED
File without changes
|
src/crawler.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
|
6 |
+
from .utils import timer_func
|
7 |
+
|
8 |
+
def remove_emoji(string):
|
9 |
+
emoji_pattern = re.compile("["
|
10 |
+
u"\U0001F300-\U0001FAD6" # emoticons
|
11 |
+
u"\U0001F300-\U0001F5FF" # symbols & pictographs
|
12 |
+
u"\U0001F680-\U0001F6FF" # transport & map symbols
|
13 |
+
u"\U0001F1E0-\U0001F1FF" # flags (iOS)
|
14 |
+
u"\U00002702-\U000027B0"
|
15 |
+
u"\U000024C2-\U0001F251"
|
16 |
+
"]+", flags=re.UNICODE)
|
17 |
+
return emoji_pattern.sub(r'', string)
|
18 |
+
|
19 |
+
def preprocess(texts):
|
20 |
+
texts = [text.replace("_", " ") for text in texts]
|
21 |
+
texts = [i.lower() for i in texts]
|
22 |
+
texts = [remove_emoji(i) for i in texts]
|
23 |
+
|
24 |
+
texts = [re.sub('[^\w\d\s]', '', i) for i in texts]
|
25 |
+
|
26 |
+
texts = [re.sub('\s+|\n', ' ', i) for i in texts]
|
27 |
+
texts = [re.sub('^\s|\s$', '', i) for i in texts]
|
28 |
+
|
29 |
+
# texts = [ViTokenizer.tokenize(i) for i in texts]
|
30 |
+
|
31 |
+
return texts
|
32 |
+
|
33 |
+
|
34 |
+
class MyCrawler:
|
35 |
+
headers = {
|
36 |
+
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.67 Safari/537.36",
|
37 |
+
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
|
38 |
+
'Accept-Language': 'en-US,en;q=0.5',
|
39 |
+
'Accept-Encoding': 'gzip, deflate',
|
40 |
+
'DNT': '1',
|
41 |
+
'Connection': 'keep-alive',
|
42 |
+
'Upgrade-Insecure-Requests': '1'
|
43 |
+
}
|
44 |
+
|
45 |
+
# headers = {
|
46 |
+
# 'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:49.0) Gecko/20100101 Firefox/49.0',
|
47 |
+
# # 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
|
48 |
+
# # 'Accept-Language': 'en-US,en;q=0.5',
|
49 |
+
# # 'Accept-Encoding': 'gzip, deflate',
|
50 |
+
# # 'DNT': '1',
|
51 |
+
# # 'Connection': 'keep-alive',
|
52 |
+
# # 'Upgrade-Insecure-Requests': '1'
|
53 |
+
# }
|
54 |
+
|
55 |
+
def getSoup(self, url: str):
|
56 |
+
req = requests.get(url,headers=self.headers)
|
57 |
+
return BeautifulSoup(req.text, 'html.parser')
|
58 |
+
|
59 |
+
def crawl_byContainer(self, url: str, article_container: str, body_class: str):
|
60 |
+
soup = self.getSoup(url)
|
61 |
+
|
62 |
+
paragraphs = soup.find(article_container,{"class": body_class})
|
63 |
+
if paragraphs:
|
64 |
+
#Crawl all paragraphs
|
65 |
+
contents = []
|
66 |
+
numOfParagraphs = 0
|
67 |
+
for p in paragraphs.find_all("p"):
|
68 |
+
contents.append(p.get_text())
|
69 |
+
numOfParagraphs += 1
|
70 |
+
# if numOfParagraphs > 10:
|
71 |
+
# break
|
72 |
+
|
73 |
+
if contents:
|
74 |
+
result = "\n".join(contents)
|
75 |
+
if (url.split("/")[2] == "vnexpress.net"):
|
76 |
+
result = self.crawl_byElement(soup, "p", "description") + "\n" + result
|
77 |
+
|
78 |
+
return result
|
79 |
+
return ""
|
80 |
+
|
81 |
+
def crawl_byElement(self, soup, element: str, ele_class: str):
|
82 |
+
print("by Elements...")
|
83 |
+
|
84 |
+
paragraph = soup.find(element,{"class": ele_class})
|
85 |
+
if paragraph:
|
86 |
+
print(paragraph.get_text())
|
87 |
+
return paragraph.get_text()
|
88 |
+
return ""
|
89 |
+
|
90 |
+
def crawl_webcontent(self, url: str):
|
91 |
+
|
92 |
+
provider = url.split("/")[2]
|
93 |
+
content = ""
|
94 |
+
|
95 |
+
if provider == "thanhnien.vn" or provider == "tuoitre.vn":
|
96 |
+
content = self.crawl_byContainer(url, "div", "afcbc-body")
|
97 |
+
elif provider == "vietnamnet.vn":
|
98 |
+
content = self.crawl_byContainer(url, "div", "maincontent")
|
99 |
+
elif provider == "vnexpress.net":
|
100 |
+
content = self.crawl_byContainer(url, "article", "fck_detail")
|
101 |
+
elif provider == "www.24h.com.vn":
|
102 |
+
content = self.crawl_byContainer(url, "article", "cate-24h-foot-arti-deta-info")
|
103 |
+
elif provider == "vov.vn":
|
104 |
+
content = self.crawl_byContainer(url, "div", "article-content")
|
105 |
+
elif provider == "vtv.vn":
|
106 |
+
content = self.crawl_byContainer(url, "div", "ta-justify")
|
107 |
+
elif provider == "vi.wikipedia.org":
|
108 |
+
content = self.crawl_byContainer(url, "div", "mw-content-ltr")
|
109 |
+
elif provider == "www.vinmec.com":
|
110 |
+
content = self.crawl_byContainer(url, "div", "block-content")
|
111 |
+
|
112 |
+
elif provider == "vietstock.vn":
|
113 |
+
content = self.crawl_byContainer(url, "div", "single_post_heading")
|
114 |
+
elif provider == "vneconomy.vn":
|
115 |
+
content = self.crawl_byContainer(url, "article", "detail-wrap")
|
116 |
+
|
117 |
+
elif provider == "dantri.com.vn":
|
118 |
+
content = self.crawl_byContainer(url, "article", "singular-container")
|
119 |
+
|
120 |
+
# elif provider == "plo.vn":
|
121 |
+
# content = self.crawl_byContainer(url, "div", "article__body")
|
122 |
+
|
123 |
+
return provider, url, content
|
124 |
+
|
125 |
+
#def crawl_redir(url):
|
126 |
+
|
127 |
+
@timer_func
|
128 |
+
def search(self, claim: str, count: int = 1):
|
129 |
+
processed_claim = preprocess([claim])[0]
|
130 |
+
|
131 |
+
num_words = 100
|
132 |
+
ls_word = processed_claim.split(" ")
|
133 |
+
claim_short = " ".join(ls_word[:num_words])
|
134 |
+
|
135 |
+
print(claim_short)
|
136 |
+
query = claim_short
|
137 |
+
# query = '+'.join(claim_short.split(" "))
|
138 |
+
|
139 |
+
try:
|
140 |
+
|
141 |
+
# print(soup.prettify())
|
142 |
+
|
143 |
+
#get all URLs
|
144 |
+
attemp_time = 0
|
145 |
+
urls = []
|
146 |
+
while len(urls) == 0 and attemp_time < 3:
|
147 |
+
req=requests.get("https://www.bing.com/search?", headers=self.headers, params={
|
148 |
+
"q": query,
|
149 |
+
"responseFilter":"-images",
|
150 |
+
"responseFilter":"-videos"
|
151 |
+
})
|
152 |
+
print("Query URL: " + req.url)
|
153 |
+
|
154 |
+
print("Crawling Attempt " + str(attemp_time))
|
155 |
+
soup = BeautifulSoup(req.text, 'html.parser')
|
156 |
+
|
157 |
+
completeData = soup.find_all("li",{"class":"b_algo"})
|
158 |
+
for data in completeData:
|
159 |
+
urls.append(data.find("a", href=True)["href"])
|
160 |
+
attemp_time += 1
|
161 |
+
time.sleep(1)
|
162 |
+
|
163 |
+
print("Got " + str(len(urls)) + " urls")
|
164 |
+
|
165 |
+
result = []
|
166 |
+
|
167 |
+
for url in urls:
|
168 |
+
print("Crawling... " + url)
|
169 |
+
provider, url, content = self.crawl_webcontent(url)
|
170 |
+
|
171 |
+
if content:
|
172 |
+
result.append({
|
173 |
+
"provider": provider,
|
174 |
+
"url": url,
|
175 |
+
"content": content
|
176 |
+
})
|
177 |
+
count -= 1
|
178 |
+
if count == 0:
|
179 |
+
break
|
180 |
+
|
181 |
+
return result
|
182 |
+
|
183 |
+
except Exception as e:
|
184 |
+
print(e)
|
185 |
+
return []
|
186 |
+
|
187 |
+
@timer_func
|
188 |
+
def searchGoogle(self, claim: str, count: int = 1):
|
189 |
+
processed_claim = preprocess([claim])[0]
|
190 |
+
|
191 |
+
num_words = 100
|
192 |
+
ls_word = processed_claim.split(" ")
|
193 |
+
claim_short = " ".join(ls_word[:num_words])
|
194 |
+
|
195 |
+
print(claim_short)
|
196 |
+
query = claim_short
|
197 |
+
# query = '+'.join(claim_short.split(" "))
|
198 |
+
|
199 |
+
try:
|
200 |
+
|
201 |
+
# print(soup.prettify())
|
202 |
+
|
203 |
+
#get all URLs
|
204 |
+
attemp_time = 0
|
205 |
+
urls = []
|
206 |
+
while len(urls) == 0 and attemp_time < 3:
|
207 |
+
req=requests.get("https://www.google.com/search?", headers=self.headers, params={
|
208 |
+
"q": query
|
209 |
+
})
|
210 |
+
print("Query URL: " + req.url)
|
211 |
+
|
212 |
+
print("Crawling Attempt " + str(attemp_time))
|
213 |
+
soup = BeautifulSoup(req.text, 'html.parser')
|
214 |
+
|
215 |
+
completeData = soup.find_all("a",{"jsname":"UWckNb"})
|
216 |
+
for data in completeData:
|
217 |
+
urls.append(data["href"])
|
218 |
+
attemp_time += 1
|
219 |
+
time.sleep(1)
|
220 |
+
|
221 |
+
print("Got " + str(len(urls)) + " urls")
|
222 |
+
|
223 |
+
result = []
|
224 |
+
|
225 |
+
for url in urls:
|
226 |
+
print("Crawling... " + url)
|
227 |
+
provider, url, content = self.crawl_webcontent(url)
|
228 |
+
|
229 |
+
if content:
|
230 |
+
result.append({
|
231 |
+
"provider": provider,
|
232 |
+
"url": url,
|
233 |
+
"content": content
|
234 |
+
})
|
235 |
+
count -= 1
|
236 |
+
if count == 0:
|
237 |
+
break
|
238 |
+
|
239 |
+
return result
|
240 |
+
|
241 |
+
except Exception as e:
|
242 |
+
print(e)
|
243 |
+
return []
|
244 |
+
|
245 |
+
@timer_func
|
246 |
+
def scraping(self, url: str):
|
247 |
+
try:
|
248 |
+
provider, url, content = self.crawl_webcontent(url)
|
249 |
+
|
250 |
+
if content:
|
251 |
+
return True
|
252 |
+
return False
|
253 |
+
|
254 |
+
except Exception as e:
|
255 |
+
print(e)
|
256 |
+
return False
|
src/mDeBERTa (ft) V6/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/mDeBERTa (ft) V6/cls.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f1c3c8eae44569fd01a746b220091611125f9eb04e09af2d60a6d80befcdb769
|
3 |
+
size 11064
|
src/mDeBERTa (ft) V6/cls_log.txt
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Step 0 -- Accuracy: 0.3039772727272727 -- macro_f1: 0.20810584530698015 -- loss: 1.0453389883041382
|
3 |
+
|
4 |
+
Step 100 -- Accuracy: 0.859375 -- macro_f1: 0.8598470398571504 -- loss: 0.11795929819345474
|
5 |
+
|
6 |
+
Step 200 -- Accuracy: 0.8747159090909091 -- macro_f1: 0.8755251824421424 -- loss: 0.22730453312397003
|
7 |
+
|
8 |
+
Step 300 -- Accuracy: 0.8536931818181818 -- macro_f1: 0.8533303214529117 -- loss: 0.18725647032260895
|
9 |
+
|
10 |
+
Step 400 -- Accuracy: 0.8690340909090909 -- macro_f1: 0.8687299763460793 -- loss: 0.28860458731651306
|
11 |
+
|
12 |
+
Step 500 -- Accuracy: 0.8798295454545455 -- macro_f1: 0.8802316356122608 -- loss: 0.6372634172439575
|
13 |
+
|
14 |
+
Step 600 -- Accuracy: 0.8610795454545455 -- macro_f1: 0.8612099869711884 -- loss: 0.41530805826187134
|
15 |
+
|
16 |
+
Step 700 -- Accuracy: 0.8491477272727272 -- macro_f1: 0.849751664990205 -- loss: 0.5970628261566162
|
17 |
+
|
18 |
+
Step 800 -- Accuracy: 0.8764204545454546 -- macro_f1: 0.8766266441048876 -- loss: 0.2515469491481781
|
19 |
+
|
20 |
+
Step 900 -- Accuracy: 0.8710227272727272 -- macro_f1: 0.8712350728851791 -- loss: 0.619756817817688
|
21 |
+
|
22 |
+
Step 1000 -- Accuracy: 0.8744318181818181 -- macro_f1: 0.8746062203201398 -- loss: 0.5634986758232117
|
23 |
+
|
24 |
+
Step 1100 -- Accuracy: 0.8735795454545454 -- macro_f1: 0.8735921715063891 -- loss: 0.2514641284942627
|
25 |
+
|
26 |
+
Step 1200 -- Accuracy: 0.8375 -- macro_f1: 0.8368621880475362 -- loss: 0.44521981477737427
|
27 |
+
|
28 |
+
Step 1300 -- Accuracy: 0.8551136363636364 -- macro_f1: 0.8555806721970362 -- loss: 0.048632219433784485
|
29 |
+
|
30 |
+
Step 1400 -- Accuracy: 0.8508522727272727 -- macro_f1: 0.8506097642423027 -- loss: 0.24613773822784424
|
31 |
+
|
32 |
+
Step 1500 -- Accuracy: 0.8673295454545454 -- macro_f1: 0.8671847303392856 -- loss: 0.1494443565607071
|
33 |
+
|
34 |
+
Step 1600 -- Accuracy: 0.834375 -- macro_f1: 0.8342641066244109 -- loss: 0.17161081731319427
|
35 |
+
|
36 |
+
Step 1700 -- Accuracy: 0.865625 -- macro_f1: 0.8651594643017528 -- loss: 0.154042050242424
|
37 |
+
|
38 |
+
Step 1800 -- Accuracy: 0.865909090909091 -- macro_f1: 0.8657615265484808 -- loss: 0.1435176134109497
|
39 |
+
|
40 |
+
Step 1900 -- Accuracy: 0.8176136363636364 -- macro_f1: 0.8171586288909666 -- loss: 0.09292535483837128
|
41 |
+
|
42 |
+
Step 2000 -- Accuracy: 0.8440340909090909 -- macro_f1: 0.843042759250924 -- loss: 0.34320467710494995
|
43 |
+
|
44 |
+
Step 2100 -- Accuracy: 0.8428977272727273 -- macro_f1: 0.8428498174495328 -- loss: 0.5764151811599731
|
45 |
+
|
46 |
+
Step 2200 -- Accuracy: 0.8417613636363637 -- macro_f1: 0.8418818479059557 -- loss: 0.28757143020629883
|
47 |
+
|
48 |
+
Step 2300 -- Accuracy: 0.840625 -- macro_f1: 0.8406394626850148 -- loss: 0.8960273861885071
|
49 |
+
|
50 |
+
Step 2400 -- Accuracy: 0.8142045454545455 -- macro_f1: 0.8140964442024906 -- loss: 0.8550783395767212
|
51 |
+
|
52 |
+
Step 2500 -- Accuracy: 0.8144886363636363 -- macro_f1: 0.8147455224461172 -- loss: 0.39625313878059387
|
53 |
+
|
54 |
+
Step 2600 -- Accuracy: 0.8053977272727273 -- macro_f1: 0.8021211300036969 -- loss: 0.3774358034133911
|
55 |
+
|
56 |
+
Step 2700 -- Accuracy: 0.8292613636363636 -- macro_f1: 0.8292382309283113 -- loss: 0.16644884645938873
|
57 |
+
|
58 |
+
Step 2800 -- Accuracy: 0.8150568181818182 -- macro_f1: 0.814290740222007 -- loss: 0.237399160861969
|
59 |
+
|
60 |
+
Step 2900 -- Accuracy: 0.8107954545454545 -- macro_f1: 0.8111709474507229 -- loss: 0.5621077418327332
|
61 |
+
|
62 |
+
Step 3000 -- Accuracy: 0.7926136363636364 -- macro_f1: 0.7930916669737708 -- loss: 0.4253169298171997
|
63 |
+
|
64 |
+
Step 3100 -- Accuracy: 0.8099431818181818 -- macro_f1: 0.8102288703246834 -- loss: 0.43165838718414307
|
65 |
+
|
66 |
+
Step 3200 -- Accuracy: 0.772159090909091 -- macro_f1: 0.7717788019596861 -- loss: 0.673878014087677
|
67 |
+
|
68 |
+
Step 3300 -- Accuracy: 0.7897727272727273 -- macro_f1: 0.7895567869064662 -- loss: 0.1990412026643753
|
69 |
+
|
70 |
+
Step 3400 -- Accuracy: 0.8008522727272728 -- macro_f1: 0.7997998535844976 -- loss: 0.4523601531982422
|
71 |
+
|
72 |
+
Step 3500 -- Accuracy: 0.7798295454545454 -- macro_f1: 0.7780260696858295 -- loss: 0.8848648071289062
|
73 |
+
|
74 |
+
Step 3600 -- Accuracy: 0.7775568181818182 -- macro_f1: 0.7779453966289696 -- loss: 0.5041539669036865
|
75 |
+
|
76 |
+
Step 3700 -- Accuracy: 0.709659090909091 -- macro_f1: 0.7069128111001839 -- loss: 0.6758942604064941
|
src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-mean/config.json
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/content/checkpoint",
|
3 |
+
"architectures": [
|
4 |
+
"DebertaV2Model"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"hidden_act": "gelu",
|
8 |
+
"hidden_dropout_prob": 0.1,
|
9 |
+
"hidden_size": 768,
|
10 |
+
"id2label": {
|
11 |
+
"0": "entailment",
|
12 |
+
"1": "neutral",
|
13 |
+
"2": "contradiction"
|
14 |
+
},
|
15 |
+
"initializer_range": 0.02,
|
16 |
+
"intermediate_size": 3072,
|
17 |
+
"label2id": {
|
18 |
+
"contradiction": 2,
|
19 |
+
"entailment": 0,
|
20 |
+
"neutral": 1
|
21 |
+
},
|
22 |
+
"layer_norm_eps": 1e-07,
|
23 |
+
"max_position_embeddings": 512,
|
24 |
+
"max_relative_positions": -1,
|
25 |
+
"model_type": "deberta-v2",
|
26 |
+
"norm_rel_ebd": "layer_norm",
|
27 |
+
"num_attention_heads": 12,
|
28 |
+
"num_hidden_layers": 12,
|
29 |
+
"pad_token_id": 0,
|
30 |
+
"pooler_dropout": 0,
|
31 |
+
"pooler_hidden_act": "gelu",
|
32 |
+
"pooler_hidden_size": 768,
|
33 |
+
"pos_att_type": [
|
34 |
+
"p2c",
|
35 |
+
"c2p"
|
36 |
+
],
|
37 |
+
"position_biased_input": false,
|
38 |
+
"position_buckets": 256,
|
39 |
+
"relative_attention": true,
|
40 |
+
"share_att_key": true,
|
41 |
+
"torch_dtype": "float32",
|
42 |
+
"transformers_version": "4.35.0",
|
43 |
+
"type_vocab_size": 0,
|
44 |
+
"vocab_size": 251000
|
45 |
+
}
|
src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-mean/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1c7e80e8237ad2969b1c989d71f97fa7b950fd239bfa8b3329f0535a0b8a2aca
|
3 |
+
size 1112897768
|
src/mDeBERTa (ft) V6/mean.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7f963dfcdad5469498af3b396c5af0e27365e59a01498c51896b9e6547851cd4
|
3 |
+
size 11071
|
src/mDeBERTa (ft) V6/mean_log.txt
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Step 0 -- Accuracy: 0.275 -- macro_f1: 0.24245894645844043 -- loss: 1.1975505352020264
|
3 |
+
|
4 |
+
Step 100 -- Accuracy: 0.8230113636363636 -- macro_f1: 0.8247917227891541 -- loss: 0.5072745084762573
|
5 |
+
|
6 |
+
Step 200 -- Accuracy: 0.8585227272727273 -- macro_f1: 0.8596474113005192 -- loss: 0.3576969504356384
|
7 |
+
|
8 |
+
Step 300 -- Accuracy: 0.8616477272727273 -- macro_f1: 0.8619445917534628 -- loss: 0.22678352892398834
|
9 |
+
|
10 |
+
Step 400 -- Accuracy: 0.8710227272727272 -- macro_f1: 0.8713149438253084 -- loss: 0.3302939534187317
|
11 |
+
|
12 |
+
Step 500 -- Accuracy: 0.8491477272727272 -- macro_f1: 0.8497535984618637 -- loss: 0.8534196615219116
|
13 |
+
|
14 |
+
Step 600 -- Accuracy: 0.8627840909090909 -- macro_f1: 0.8630171351987245 -- loss: 0.27207863330841064
|
15 |
+
|
16 |
+
Step 700 -- Accuracy: 0.8676136363636363 -- macro_f1: 0.8681189318753203 -- loss: 0.5472040772438049
|
17 |
+
|
18 |
+
Step 800 -- Accuracy: 0.8480113636363636 -- macro_f1: 0.8474828960740969 -- loss: 0.20389704406261444
|
19 |
+
|
20 |
+
Step 900 -- Accuracy: 0.8625 -- macro_f1: 0.8627369387200629 -- loss: 0.7003616094589233
|
21 |
+
|
22 |
+
Step 1000 -- Accuracy: 0.8471590909090909 -- macro_f1: 0.8474576933366409 -- loss: 0.39897170662879944
|
23 |
+
|
24 |
+
Step 1100 -- Accuracy: 0.8647727272727272 -- macro_f1: 0.8648449015557045 -- loss: 0.30028393864631653
|
25 |
+
|
26 |
+
Step 1200 -- Accuracy: 0.8355113636363637 -- macro_f1: 0.8357176579844655 -- loss: 0.5329824090003967
|
27 |
+
|
28 |
+
Step 1300 -- Accuracy: 0.8318181818181818 -- macro_f1: 0.832158484567787 -- loss: 0.04946904629468918
|
29 |
+
|
30 |
+
Step 1400 -- Accuracy: 0.8275568181818181 -- macro_f1: 0.8270568913757921 -- loss: 0.290753036737442
|
31 |
+
|
32 |
+
Step 1500 -- Accuracy: 0.8619318181818182 -- macro_f1: 0.8620216901652552 -- loss: 0.17760200798511505
|
33 |
+
|
34 |
+
Step 1600 -- Accuracy: 0.8366477272727273 -- macro_f1: 0.8372501215741125 -- loss: 0.18745465576648712
|
35 |
+
|
36 |
+
Step 1700 -- Accuracy: 0.8556818181818182 -- macro_f1: 0.8555692365839257 -- loss: 0.09077112376689911
|
37 |
+
|
38 |
+
Step 1800 -- Accuracy: 0.8571022727272727 -- macro_f1: 0.8569408344903815 -- loss: 0.24079212546348572
|
39 |
+
|
40 |
+
Step 1900 -- Accuracy: 0.8122159090909091 -- macro_f1: 0.8117034674801616 -- loss: 0.3681311309337616
|
41 |
+
|
42 |
+
Step 2000 -- Accuracy: 0.8318181818181818 -- macro_f1: 0.8319676688379705 -- loss: 0.2374744713306427
|
43 |
+
|
44 |
+
Step 2100 -- Accuracy: 0.8443181818181819 -- macro_f1: 0.8442918629955193 -- loss: 0.4600515365600586
|
45 |
+
|
46 |
+
Step 2200 -- Accuracy: 0.8278409090909091 -- macro_f1: 0.8269904995679983 -- loss: 0.3283902704715729
|
47 |
+
|
48 |
+
Step 2300 -- Accuracy: 0.8298295454545455 -- macro_f1: 0.8299882032010862 -- loss: 1.0965081453323364
|
49 |
+
|
50 |
+
Step 2400 -- Accuracy: 0.8159090909090909 -- macro_f1: 0.8159808860940237 -- loss: 0.7295159697532654
|
51 |
+
|
52 |
+
Step 2500 -- Accuracy: 0.8159090909090909 -- macro_f1: 0.8142475187664063 -- loss: 0.3925968408584595
|
53 |
+
|
54 |
+
Step 2600 -- Accuracy: 0.8204545454545454 -- macro_f1: 0.820545798600696 -- loss: 0.3808274567127228
|
55 |
+
|
56 |
+
Step 2700 -- Accuracy: 0.8198863636363637 -- macro_f1: 0.8199413434559383 -- loss: 0.26008090376853943
|
57 |
+
|
58 |
+
Step 2800 -- Accuracy: 0.8056818181818182 -- macro_f1: 0.8051566431375038 -- loss: 0.20567485690116882
|
59 |
+
|
60 |
+
Step 2900 -- Accuracy: 0.784375 -- macro_f1: 0.7848921849530183 -- loss: 0.5506788492202759
|
61 |
+
|
62 |
+
Step 3000 -- Accuracy: 0.8153409090909091 -- macro_f1: 0.8150634367874668 -- loss: 0.4250873923301697
|
63 |
+
|
64 |
+
Step 3100 -- Accuracy: 0.7991477272727273 -- macro_f1: 0.8000715520252392 -- loss: 0.4798588752746582
|
65 |
+
|
66 |
+
Step 3200 -- Accuracy: 0.7840909090909091 -- macro_f1: 0.7836356305606565 -- loss: 0.5604580640792847
|
67 |
+
|
68 |
+
Step 3300 -- Accuracy: 0.7977272727272727 -- macro_f1: 0.7965403402362528 -- loss: 0.26682722568511963
|
69 |
+
|
70 |
+
Step 3400 -- Accuracy: 0.809375 -- macro_f1: 0.8087947373143304 -- loss: 0.3252097964286804
|
71 |
+
|
72 |
+
Step 3500 -- Accuracy: 0.7568181818181818 -- macro_f1: 0.7548780108676749 -- loss: 0.9467527866363525
|
73 |
+
|
74 |
+
Step 3600 -- Accuracy: 0.7889204545454546 -- macro_f1: 0.7892382882596812 -- loss: 0.29441171884536743
|
75 |
+
|
76 |
+
Step 3700 -- Accuracy: 0.7227272727272728 -- macro_f1: 0.7227876418017654 -- loss: 0.8389160633087158
|
src/mDeBERTa (ft) V6/plot.png
ADDED
src/mDeBERTa (ft) V6/public_train_v4.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:56c03b9bb2cab8ffbe138badea76b6275ebad727e99f5040d2a8c21f2dcfaff2
|
3 |
+
size 227113690
|
src/myNLI.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
|
3 |
+
from sentence_transformers import SentenceTransformer, util
|
4 |
+
import nltk
|
5 |
+
|
6 |
+
# import datasets
|
7 |
+
from datasets import Dataset, DatasetDict
|
8 |
+
|
9 |
+
from typing import List
|
10 |
+
|
11 |
+
from .utils import timer_func
|
12 |
+
from .nli_v3 import NLI_model
|
13 |
+
from .crawler import MyCrawler
|
14 |
+
|
15 |
+
int2label = {0:'SUPPORTED', 1:'NEI', 2:'REFUTED'}
|
16 |
+
|
17 |
+
class FactChecker:
|
18 |
+
|
19 |
+
@timer_func
|
20 |
+
def __init__(self):
|
21 |
+
self.INPUT_TYPE = "mean"
|
22 |
+
self.load_model()
|
23 |
+
|
24 |
+
@timer_func
|
25 |
+
def load_model(self):
|
26 |
+
self.envir = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
27 |
+
|
28 |
+
# Load LLM
|
29 |
+
self.tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli") # LOAD mDEBERTa TOKENIZER
|
30 |
+
self.mDeBertaModel = AutoModel.from_pretrained(f"src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-{self.INPUT_TYPE}") # LOAD FINETUNED MODEL
|
31 |
+
# Load classifier model
|
32 |
+
self.checkpoints = torch.load(f"src/mDeBERTa (ft) V6/{self.INPUT_TYPE}.pt", map_location=self.envir)
|
33 |
+
|
34 |
+
self.classifierModel = NLI_model(768, torch.tensor([0., 0., 0.])).to(self.envir)
|
35 |
+
self.classifierModel.load_state_dict(self.checkpoints['model_state_dict'])
|
36 |
+
|
37 |
+
#Load model for predict similarity
|
38 |
+
self.model_sbert = SentenceTransformer('keepitreal/vietnamese-sbert')
|
39 |
+
|
40 |
+
@timer_func
|
41 |
+
def get_similarity_v2(self, src_sents, dst_sents, threshold = 0.4):
|
42 |
+
corpus_embeddings = self.model_sbert.encode(dst_sents, convert_to_tensor=True)
|
43 |
+
top_k = min(5, len(dst_sents))
|
44 |
+
ls_top_results = []
|
45 |
+
for query in src_sents:
|
46 |
+
query_embedding = self.model_sbert.encode(query, convert_to_tensor=True)
|
47 |
+
# We use cosine-similarity and torch.topk to find the highest 5 scores
|
48 |
+
cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
|
49 |
+
top_results = torch.topk(cos_scores, k=top_k)
|
50 |
+
|
51 |
+
# print("\n\n======================\n\n")
|
52 |
+
# print("Query:", src_sents)
|
53 |
+
# print("\nTop 5 most similar sentences in corpus:")
|
54 |
+
ls_top_results.append({
|
55 |
+
"top_k": top_k,
|
56 |
+
"claim": query,
|
57 |
+
"sim_score": top_results,
|
58 |
+
"evidences": [dst_sents[idx] for _, idx in zip(top_results[0], top_results[1])],
|
59 |
+
})
|
60 |
+
|
61 |
+
# for score, idx in zip(top_results[0], top_results[1]):
|
62 |
+
# print(dst_sents[idx], "(Score: {:.4f})".format(score))
|
63 |
+
return None,ls_top_results
|
64 |
+
|
65 |
+
@timer_func
|
66 |
+
def inferSample(self, evidence, claim):
|
67 |
+
|
68 |
+
@timer_func
|
69 |
+
def mDeBERTa_tokenize(data): # mDeBERTa model: Taking input_ids
|
70 |
+
premises = [premise for premise, _ in data['sample']]
|
71 |
+
hypothesis = [hypothesis for _, hypothesis in data['sample']]
|
72 |
+
|
73 |
+
with torch.no_grad():
|
74 |
+
input_token = (self.tokenizer(premises, hypothesis, truncation=True, return_tensors="pt", padding = True)['input_ids']).to(self.envir)
|
75 |
+
embedding = self.mDeBertaModel(input_token).last_hidden_state
|
76 |
+
|
77 |
+
mean_embedding = torch.mean(embedding[:, 1:, :], dim = 1)
|
78 |
+
cls_embedding = embedding[:, 0, :]
|
79 |
+
|
80 |
+
return {'mean':mean_embedding, 'cls':cls_embedding}
|
81 |
+
|
82 |
+
@timer_func
|
83 |
+
def predict_mapping(batch):
|
84 |
+
with torch.no_grad():
|
85 |
+
predict_label, predict_prob = self.classifierModel.predict_step((batch[self.INPUT_TYPE].to(self.envir), None))
|
86 |
+
return {'label':predict_label, 'prob':-predict_prob}
|
87 |
+
|
88 |
+
# Mapping the predict label into corresponding string labels
|
89 |
+
@timer_func
|
90 |
+
def output_predictedDataset(predict_dataset):
|
91 |
+
for record in predict_dataset:
|
92 |
+
labels = int2label[ record['label'].item() ]
|
93 |
+
confidence = record['prob'].item()
|
94 |
+
|
95 |
+
return {'labels':labels, 'confidence':confidence}
|
96 |
+
|
97 |
+
dataset = {'sample':[(evidence, claim)], 'key': [0]}
|
98 |
+
output_dataset = DatasetDict({
|
99 |
+
'infer': Dataset.from_dict(dataset)
|
100 |
+
})
|
101 |
+
|
102 |
+
@timer_func
|
103 |
+
def tokenize_dataset():
|
104 |
+
|
105 |
+
tokenized_dataset = output_dataset.map(mDeBERTa_tokenize, batched=True, batch_size=1)
|
106 |
+
return tokenized_dataset
|
107 |
+
|
108 |
+
tokenized_dataset = tokenize_dataset()
|
109 |
+
tokenized_dataset = tokenized_dataset.with_format("torch", [self.INPUT_TYPE, 'key'])
|
110 |
+
# Running inference step
|
111 |
+
predicted_dataset = tokenized_dataset.map(predict_mapping, batched=True, batch_size=tokenized_dataset['infer'].num_rows)
|
112 |
+
return output_predictedDataset(predicted_dataset['infer'])
|
113 |
+
|
114 |
+
@timer_func
|
115 |
+
def predict_vt(self, claim: str) -> List:
|
116 |
+
# import pdb; pdb.set_trace()
|
117 |
+
# step 1: crawl evidences from bing search
|
118 |
+
crawler = MyCrawler()
|
119 |
+
evidences = crawler.searchGoogle(claim)
|
120 |
+
|
121 |
+
# evidences = crawler.get_evidences(claim)
|
122 |
+
# step 2: use emebdding setences to search most related setences
|
123 |
+
if len(evidences) == 0:
|
124 |
+
return None
|
125 |
+
|
126 |
+
for evidence in evidences:
|
127 |
+
print(evidence['url'])
|
128 |
+
top_evidence = evidence["content"]
|
129 |
+
|
130 |
+
post_message = nltk.tokenize.sent_tokenize(claim)
|
131 |
+
evidences = nltk.tokenize.sent_tokenize(top_evidence)
|
132 |
+
_, top_rst = self.get_similarity_v2(post_message, evidences)
|
133 |
+
|
134 |
+
print(top_rst)
|
135 |
+
|
136 |
+
ls_evidence, final_verdict = self.get_result_nli_v2(top_rst)
|
137 |
+
|
138 |
+
print("FINAL: " + final_verdict)
|
139 |
+
# _, top_rst = self.get_similarity_v1(post_message, evidences)
|
140 |
+
# ls_evidence, final_verdict = self.get_result_nli_v1(post_message, top_rst, evidences)
|
141 |
+
return ls_evidence, final_verdict
|
142 |
+
|
143 |
+
|
144 |
+
@timer_func
|
145 |
+
def predict(self, claim):
|
146 |
+
crawler = MyCrawler()
|
147 |
+
evidences = crawler.searchGoogle(claim)
|
148 |
+
|
149 |
+
if evidences:
|
150 |
+
tokenized_claim = nltk.tokenize.sent_tokenize(claim)
|
151 |
+
evidence = evidences[0]
|
152 |
+
tokenized_evidence = nltk.tokenize.sent_tokenize(evidence["content"])
|
153 |
+
# print("TOKENIZED EVIDENCES")
|
154 |
+
# print(tokenized_evidence)
|
155 |
+
_, top_rst = self.get_similarity_v2(tokenized_claim, tokenized_evidence)
|
156 |
+
|
157 |
+
processed_evidence = "\n".join(top_rst[0]["evidences"])
|
158 |
+
print(processed_evidence)
|
159 |
+
|
160 |
+
nli_result = self.inferSample(processed_evidence, claim)
|
161 |
+
return {
|
162 |
+
"claim": claim,
|
163 |
+
"label": nli_result["labels"],
|
164 |
+
"confidence": nli_result['confidence'],
|
165 |
+
"evidence": processed_evidence if nli_result["labels"] != "NEI" else "",
|
166 |
+
"provider": evidence['provider'],
|
167 |
+
"url": evidence['url']
|
168 |
+
}
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
@timer_func
|
173 |
+
def predict_nofilter(self, claim):
|
174 |
+
crawler = MyCrawler()
|
175 |
+
evidences = crawler.searchGoogle(claim)
|
176 |
+
tokenized_claim = nltk.tokenize.sent_tokenize(claim)
|
177 |
+
|
178 |
+
evidence = evidences[0]
|
179 |
+
|
180 |
+
processed_evidence = evidence['content']
|
181 |
+
|
182 |
+
nli_result = self.inferSample(processed_evidence, claim)
|
183 |
+
return {
|
184 |
+
"claim": claim,
|
185 |
+
"label": nli_result["labels"],
|
186 |
+
"confidence": nli_result['confidence'],
|
187 |
+
"evidence": processed_evidence if nli_result["labels"] != "NEI" else "",
|
188 |
+
"provider": evidence['provider'],
|
189 |
+
"url": evidence['url']
|
190 |
+
}
|
src/nli_v3.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
from transformers import AutoModel, AutoTokenizer
|
6 |
+
|
7 |
+
# import datasets
|
8 |
+
from datasets import Dataset, DatasetDict
|
9 |
+
|
10 |
+
from sklearn.metrics import classification_report
|
11 |
+
from sklearn.metrics._classification import _check_targets
|
12 |
+
|
13 |
+
envir = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
14 |
+
|
15 |
+
int2label = {0:'SUPPORTED', 1:'NEI', 2:'REFUTED'}
|
16 |
+
|
17 |
+
class NLI_model(nn.Module):
|
18 |
+
def __init__(self, input_dims, class_weights=torch.tensor([0., 0., 0.])):
|
19 |
+
super(NLI_model, self).__init__()
|
20 |
+
|
21 |
+
self.classification = nn.Sequential(
|
22 |
+
nn.Linear(input_dims, 3)
|
23 |
+
)
|
24 |
+
|
25 |
+
self.criterion = nn.CrossEntropyLoss(class_weights)
|
26 |
+
|
27 |
+
def forward(self, input):
|
28 |
+
output_linear = self.classification(input)
|
29 |
+
return output_linear
|
30 |
+
|
31 |
+
def training_step(self, train_batch, batch_idx=0):
|
32 |
+
input_data, targets = train_batch
|
33 |
+
outputs = self.forward(input_data)
|
34 |
+
loss = self.criterion(outputs, targets)
|
35 |
+
return loss
|
36 |
+
|
37 |
+
def predict_step(self, batch, batch_idx=0):
|
38 |
+
input_data, _ = batch
|
39 |
+
outputs = self.forward(input_data)
|
40 |
+
prob = outputs.softmax(dim = -1)
|
41 |
+
sort_prob, sort_indices = torch.sort(-prob, 1)
|
42 |
+
return sort_indices[:,0], sort_prob[:,0]
|
43 |
+
|
44 |
+
def validation_step(self, val_batch, batch_idx=0):
|
45 |
+
_, targets = val_batch
|
46 |
+
sort_indices, _ = self.predict_step(val_batch, batch_idx)
|
47 |
+
report = classification_report(list(targets.to('cpu').numpy()), list(sort_indices.to('cpu').numpy()), output_dict=True, zero_division = 1)
|
48 |
+
return report
|
49 |
+
|
50 |
+
def test_step(self, batch, dict_form, batch_idx=0):
|
51 |
+
_, targets = batch
|
52 |
+
sort_indices, _ = self.predict_step(batch, batch_idx)
|
53 |
+
report = classification_report(targets.to('cpu').numpy(), sort_indices.to('cpu').numpy(), output_dict=dict_form, zero_division = 1)
|
54 |
+
return report
|
55 |
+
|
56 |
+
def configure_optimizers(self):
|
57 |
+
return torch.optim.Adam(self.parameters(), lr = 1e-5)
|
58 |
+
|
59 |
+
|
60 |
+
def inferSample(evidence, claim, tokenizer, mDeBertaModel, classifierModel, input_type):
|
61 |
+
|
62 |
+
def mDeBERTa_tokenize(data): # mDeBERTa model: Taking input_ids
|
63 |
+
premises = [premise for premise, _ in data['sample']]
|
64 |
+
hypothesis = [hypothesis for _, hypothesis in data['sample']]
|
65 |
+
|
66 |
+
with torch.no_grad():
|
67 |
+
input_token = (tokenizer(premises, hypothesis, truncation=True, return_tensors="pt", padding = True)['input_ids']).to(envir)
|
68 |
+
embedding = mDeBertaModel(input_token).last_hidden_state
|
69 |
+
|
70 |
+
mean_embedding = torch.mean(embedding[:, 1:, :], dim = 1)
|
71 |
+
cls_embedding = embedding[:, 0, :]
|
72 |
+
|
73 |
+
return {'mean':mean_embedding, 'cls':cls_embedding}
|
74 |
+
|
75 |
+
def predict_mapping(batch):
|
76 |
+
with torch.no_grad():
|
77 |
+
predict_label, predict_prob = classifierModel.predict_step((batch[input_type].to(envir), None))
|
78 |
+
return {'label':predict_label, 'prob':-predict_prob}
|
79 |
+
|
80 |
+
# Mapping the predict label into corresponding string labels
|
81 |
+
def output_predictedDataset(predict_dataset):
|
82 |
+
for record in predict_dataset:
|
83 |
+
labels = int2label[ record['label'].item() ]
|
84 |
+
confidence = record['prob'].item()
|
85 |
+
|
86 |
+
return {'labels':labels, 'confidence':confidence}
|
87 |
+
|
88 |
+
dataset = {'sample':[(evidence, claim)], 'key': [0]}
|
89 |
+
|
90 |
+
output_dataset = DatasetDict({
|
91 |
+
'infer': Dataset.from_dict(dataset)
|
92 |
+
})
|
93 |
+
|
94 |
+
tokenized_dataset = output_dataset.map(mDeBERTa_tokenize, batched=True, batch_size=1)
|
95 |
+
tokenized_dataset = tokenized_dataset.with_format("torch", [input_type, 'key'])
|
96 |
+
|
97 |
+
# Running inference step
|
98 |
+
predicted_dataset = tokenized_dataset.map(predict_mapping, batched=True, batch_size=tokenized_dataset['infer'].num_rows)
|
99 |
+
return output_predictedDataset(predicted_dataset['infer'])
|
100 |
+
|
101 |
+
if __name__ == '__main__':
|
102 |
+
# CHANGE 'INPUT_TYPE' TO CHANGE MODEL
|
103 |
+
INPUT_TYPE = 'mean' # USE "MEAN" OR "CLS" LAST HIDDEN STATE
|
104 |
+
|
105 |
+
# Load LLM
|
106 |
+
tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli") # LOAD mDEBERTa TOKENIZER
|
107 |
+
mDeBertaModel = AutoModel.from_pretrained(f"src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-{INPUT_TYPE}") # LOAD FINETUNED MODEL
|
108 |
+
# Load classifier model
|
109 |
+
checkpoints = torch.load(f"src/mDeBERTa (ft) V6/{INPUT_TYPE}.pt", map_location=envir)
|
110 |
+
classifierModel = NLI_model(768, torch.tensor([0., 0., 0.])).to(envir)
|
111 |
+
classifierModel.load_state_dict(checkpoints['model_state_dict'])
|
112 |
+
|
113 |
+
evidence = "Sau khi thẩm định, Liên đoàn Bóng đá châu Á AFC xác nhận thủ thành mới nhập quốc tịch của Việt Nam Filip Nguyễn đủ điều kiện thi đấu ở Asian Cup 2024."
|
114 |
+
claim = "Filip Nguyễn đủ điều kiện dự Asian Cup 2024"
|
115 |
+
print(inferSample(evidence, claim, tokenizer, mDeBertaModel, classifierModel, INPUT_TYPE))
|
src/utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from time import time
|
2 |
+
|
3 |
+
def timer_func(func):
|
4 |
+
# This function shows the execution time of
|
5 |
+
# the function object passed
|
6 |
+
def wrap_func(*args, **kwargs):
|
7 |
+
t1 = time()
|
8 |
+
result = func(*args, **kwargs)
|
9 |
+
t2 = time()
|
10 |
+
print(f'Function {func.__name__!r} executed in {(t2-t1):.4f}s')
|
11 |
+
return result
|
12 |
+
return wrap_func
|