|
from typing import Dict, List, Any |
|
import numpy as np |
|
import pickle |
|
|
|
from sklearn.preprocessing import MultiLabelBinarizer |
|
from transformers import AutoTokenizer |
|
import torch |
|
|
|
from eurovoc import EurovocTagger |
|
|
|
# Name of the pretrained BERT checkpoint: it selects the tokenizer below and is
# also passed to EurovocTagger.from_pretrained as ``bert_model_name``.
BERT_MODEL_NAME: str = "nlpaueb/legal-bert-base-uncased"

# Token limit handed to the tokenizer for each model call (the handler also
# reuses it as the chunk size when splitting raw text).
MAX_LEN: int = 512

# Cap on how much raw text is scored: fifty MAX_LEN-sized chunks.
TEXT_MAX_LEN: int = 50 * MAX_LEN

# Shared tokenizer, loaded once when this module is imported.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
|
|
|
|
|
class EndpointHandler:
    """Inference handler that tags text with EuroVoc labels.

    On construction it unpickles a fitted ``MultiLabelBinarizer`` (the label
    vocabulary) and loads an ``EurovocTagger`` from ``path``. Each request's
    text is split into chunks, every chunk is scored by the model, and the
    per-label scores are averaged before thresholding / top-k selection.
    """

    # Class-level placeholder only; every instance replaces it in __init__ with
    # the binarizer unpickled from ``mlb.pickle``.
    mlb = MultiLabelBinarizer()

    def __init__(self, path=""):
        """Load the label binarizer and the model.

        Args:
            path: Location of ``mlb.pickle``; also passed as the first argument
                to ``EurovocTagger.from_pretrained``.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code, so
        # ``mlb.pickle`` must only ever come from a trusted model repository.
        with open(f"{path}/mlb.pickle", "rb") as handle:
            self.mlb = pickle.load(handle)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = EurovocTagger.from_pretrained(
            path,
            bert_model_name=BERT_MODEL_NAME,
            n_classes=len(self.mlb.classes_),
            map_location=self.device,
        )
        # Pin the weights to the same device that _get_prediction sends the
        # input tensors to, independent of how from_pretrained placed them.
        self.model.to(self.device)
        self.model.eval()
        self.model.freeze()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score the request text and return the best-matching labels.

        Args:
            data: Request payload. The following keys are popped from it:
                ``inputs`` (str, required): the text to tag.
                ``topk`` (int, default 5): maximum number of labels returned.
                ``threshold`` (float, default 0.16): a label is kept only if its
                    score is strictly greater than this value.
                ``debug`` (bool, default False): additionally return the raw
                    score matrix (as nested lists) and the input text.

        Returns:
            ``{"results": [{"label": ..., "score": float}, ...]}`` ordered by
            descending score; with ``debug`` also ``"values"`` and ``"input"``.

        Raises:
            TypeError: if ``inputs`` is missing or is not a string.
        """
        text = data.pop("inputs", data)
        topk = data.pop("topk", 5)
        threshold = data.pop("threshold", 0.16)
        debug = data.pop("debug", False)

        # Without "inputs", ``text`` falls back to the payload dict itself,
        # which cannot be chunked; fail early with a clear message.
        if not isinstance(text, str):
            raise TypeError(
                f"'inputs' must be a string, got {type(text).__name__}"
            )

        prediction = self.get_prediction(text)

        # Row 0 of the prediction holds one score per label, aligned with
        # self.mlb.classes_ (presumably shape (1, n_classes) -- see
        # get_prediction).
        scored = [
            {"label": label, "score": float(score)}
            for label, score in zip(self.mlb.classes_, prediction[0].tolist())
        ]
        scored.sort(key=lambda r: r["score"], reverse=True)
        results = [r for r in scored if r["score"] > threshold][:topk]

        if debug:
            # .tolist() keeps the payload JSON-serialisable (ndarray is not).
            return {"results": results, "values": prediction.tolist(), "input": text}
        return {"results": results}

    def get_prediction(self, text):
        """Return label scores for ``text`` averaged over fixed-size chunks.

        Only the first ``TEXT_MAX_LEN`` characters are considered. They are cut
        into ``MAX_LEN``-character chunks, each chunk is scored independently by
        ``_get_prediction``, and the results are averaged element-wise.

        Returns:
            A numpy array of averaged scores whose row 0 is consumed by
            ``__call__``. For empty ``text`` an all-zero array of shape
            ``(1, n_classes)`` is returned instead of averaging nothing.
        """
        # NOTE(review): chunks are sliced by *characters*, while MAX_LEN is also
        # the *token* limit given to the tokenizer, so a chunk typically fills
        # only part of the 512-token window. Left unchanged because the default
        # threshold in __call__ may have been tuned against this -- confirm
        # before switching to token-based chunking.
        limit = min(len(text), TEXT_MAX_LEN)
        chunks = [text[i:i + MAX_LEN] for i in range(0, limit, MAX_LEN)]
        if not chunks:
            # Averaging an empty list yields a NaN scalar, which __call__ could
            # not index; report zero scores for every label instead.
            return np.zeros((1, len(self.mlb.classes_)))

        predictions = [self._get_prediction(chunk) for chunk in chunks]
        return np.array(predictions).mean(axis=0)

    def _get_prediction(self, text):
        """Run the model on a single chunk of text and return its scores.

        The chunk is tokenized with special tokens, padded/truncated to
        ``MAX_LEN`` tokens, and fed to the model as a one-item batch. The
        second element of the model's output tuple is returned as numpy.
        """
        item = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # The tokenizer produces CPU tensors; move them to the model's device so
        # a CUDA-resident model does not fail with a device mismatch.
        input_ids = item["input_ids"].to(self.device)
        attention_mask = item["attention_mask"].to(self.device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            _, prediction = self.model(input_ids, attention_mask)

        return prediction.detach().cpu().numpy()
|
|
|
|