# Spaces:
# Runtime error
# Runtime error
from pyvis.network import Network | |
from GoogleNews import GoogleNews | |
from newspaper import Article, ArticleException | |
import math | |
import torch | |
from kb import KB | |
import pickle | |
def _build_relation(subject, relation, object_):
    """Return one relation dict with surrounding whitespace stripped."""
    return {
        'head': subject.strip(),
        'type': relation.strip(),
        'tail': object_.strip(),
    }


def extract_relations_from_model_output(text):
    """Parse linearised triplets out of a decoded model output string.

    The string is scanned token by token (whitespace separated) after the
    ``<s>``, ``<pad>`` and ``</s>`` markers have been removed:

    * tokens following ``<triplet>`` accumulate into the head entity,
    * tokens following ``<subj>`` accumulate into the tail entity,
    * tokens following ``<obj>`` accumulate into the relation type.

    A relation is emitted whenever a new ``<triplet>`` or ``<subj>`` marker
    is met while a relation type is pending, and once more at the end of the
    text if head, type and tail are all non-empty.

    Args:
        text: Decoded model output, special tokens included.

    Returns:
        A list of dicts with the keys ``'head'``, ``'type'`` and ``'tail'``,
        in the order in which they were completed.
    """
    relations = []
    subject, relation, object_ = '', '', ''
    # 'x' means "no marker seen yet": plain tokens are ignored in that state.
    current = 'x'
    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                # A new triplet starts: flush the one that was in progress.
                relations.append(_build_relation(subject, relation, object_))
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                # Same head, new tail: flush the previous (head, type, tail).
                relations.append(_build_relation(subject, relation, object_))
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token
    # Flush the trailing relation only when it is complete.
    if subject != '' and relation != '' and object_ != '':
        relations.append(_build_relation(subject, relation, object_))
    return relations
def _compute_spans_boundaries(num_tokens, span_length):
    """Return ``[start, end]`` token windows of width ``span_length``.

    The overlap between consecutive windows is rounded up so that, when more
    than one window is needed, the last window never runs past
    ``num_tokens``; every slice then has the same length and the slices can
    be stacked into one tensor.
    """
    num_spans = math.ceil(num_tokens / span_length)
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    stride = span_length - overlap
    return [[index * stride, index * stride + span_length]
            for index in range(num_spans)]


def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
                    article_publish_date=None, verbose=False):
    """Extract relations from ``text`` and collect them in a new ``KB``.

    The text is tokenized as a whole, cut into overlapping windows of
    ``span_length`` tokens, and each window is passed to ``model.generate``.
    Every generated sequence is parsed with
    ``extract_relations_from_model_output`` and each relation is added to the
    knowledge base together with the boundaries of the window it came from.

    Args:
        text: Raw text to process.
        model: Object exposing ``generate(input_ids=..., attention_mask=...,
            **generation_kwargs)``.
        tokenizer: Callable returning ``input_ids`` / ``attention_mask``
            tensors when called with ``return_tensors="pt"``, and exposing
            ``batch_decode``.
        article_url: Key under which the span metadata is stored in
            ``relation["meta"]``.
        span_length: Number of tokens per window.
        article_title: Forwarded to ``KB.add_relation``.
        article_publish_date: Forwarded to ``KB.add_relation``.
        verbose: When true, print token / span diagnostics.

    Returns:
        The populated ``KB`` (empty when the tokenizer yields no tokens).
    """
    # tokenize whole text
    inputs = tokenizer([text], return_tensors="pt")

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    spans_boundaries = _compute_spans_boundaries(num_tokens, span_length)
    if verbose:
        print(f"Input has {len(spans_boundaries)} spans")
        print(f"Span boundaries are {spans_boundaries}")

    kb = KB()
    if not spans_boundaries:
        # Nothing to generate from; torch.stack would fail on an empty list.
        return kb

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][begin:end]
                  for begin, end in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][begin:end]
                    for begin, end in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # decode relations; special tokens are kept because the parser relies on
    # the <triplet>/<subj>/<obj> markers
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # fill kb: predictions come in groups of num_return_sequences per span
    for pred_index, sentence_pred in enumerate(decoded_preds):
        span = spans_boundaries[pred_index // num_return_sequences]
        for relation in extract_relations_from_model_output(sentence_pred):
            relation["meta"] = {
                article_url: {
                    "spans": [span]
                }
            }
            kb.add_relation(relation, article_title, article_publish_date)
    return kb
def get_article(url):
    """Download and parse the page at ``url``; return the ``Article`` object."""
    fetched = Article(url)
    fetched.download()
    fetched.parse()
    return fetched
def from_url_to_kb(url, model, tokenizer):
    """Fetch the article at ``url`` and build a knowledge base from its text.

    The article's title and publish date are forwarded to
    ``from_text_to_kb``; the article's own ``url`` attribute is used as the
    metadata key.
    """
    article = get_article(url)
    return from_text_to_kb(
        article.text,
        model,
        tokenizer,
        article.url,
        article_title=article.title,
        article_publish_date=article.publish_date,
    )
def get_news_links(query, lang="en", region="US", pages=1):
    """Search Google News for ``query`` and return the distinct result links.

    Args:
        query: Search string.
        lang: Language passed to ``GoogleNews``.
        region: Region passed to ``GoogleNews``.
        pages: Number of result pages to request.

    Returns:
        A list of unique URLs; the order is unspecified because duplicates
        are removed through a ``set``.
    """
    client = GoogleNews(lang=lang, region=region)
    client.search(query)
    collected = []
    for page_number in range(pages):
        client.get_page(page_number)
        collected.extend(client.get_links())
    unique_links = set(collected)
    return list(unique_links)
def from_urls_to_kb(urls, model, tokenizer, verbose=False):
    """Build one knowledge base out of the articles behind ``urls``.

    Each URL is processed with ``from_url_to_kb`` and merged into a single
    ``KB``. URLs that raise ``ArticleException`` are skipped (reported only
    when ``verbose`` is true).
    """
    merged = KB()
    if verbose:
        print(f"{len(urls)} links to visit")
    for url in urls:
        if verbose:
            print(f"Visiting {url}...")
        try:
            merged.merge_with_kb(from_url_to_kb(url, model, tokenizer))
        except ArticleException:
            # Best effort: one unreachable article must not abort the batch.
            if verbose:
                print(f" Couldn't download article at url {url}")
    return merged
def save_network_html(kb, filename="network.html"):
    """Render ``kb`` as a directed pyvis graph and show it as ``filename``.

    Every item of ``kb.entities`` becomes a green circular node; every item
    of ``kb.relations`` becomes an edge from its ``"head"`` to its ``"tail"``
    whose title and label are the relation ``"type"``.
    """
    graph = Network(directed=True, width="700px", height="700px")

    # nodes
    entity_color = "#00FF00"
    for entity in kb.entities:
        graph.add_node(entity, shape="circle", color=entity_color)

    # edges
    for rel in kb.relations:
        rel_type = rel["type"]
        graph.add_edge(rel["head"], rel["tail"], title=rel_type, label=rel_type)

    # layout and output
    layout_options = {
        "node_distance": 200,
        "central_gravity": 0.2,
        "spring_length": 200,
        "spring_strength": 0.05,
        "damping": 0.09,
    }
    graph.repulsion(**layout_options)
    graph.set_edge_smooth('dynamic')
    graph.show(filename)
def save_kb(kb, filename):
    """Serialise ``kb`` with ``pickle`` into the file at ``filename``."""
    with open(filename, "wb") as out_file:
        pickle.dump(kb, out_file)
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that maps every class named ``KB`` to the imported ``KB``.

    The module recorded in the pickle is ignored for that name, presumably so
    that files saved while ``KB`` lived under another module path (for
    example ``__main__``) can still be loaded -- confirm against how the
    pickles were produced.

    SECURITY: all other globals are resolved by the default ``find_class``,
    so loading a pickle can import and call arbitrary code. Only use this on
    files from a trusted source.
    """

    def find_class(self, module, name):
        """Return ``KB`` for the name ``'KB'``, else defer to the default."""
        if name == 'KB':
            return KB
        return super().find_class(module, name)
def load_kb(filename):
    """Load and return the object pickled in ``filename``.

    The file is read through ``CustomUnpickler``.

    SECURITY: unpickling executes code chosen by whoever wrote the file; only
    call this on files from a trusted source.
    """
    with open(filename, "rb") as f:
        return CustomUnpickler(f).load()