File size: 810 Bytes
30cd0bc ee8ab93 30cd0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
from transformers import MLukeTokenizer
from torch import nn
# Read one line of text from stdin, run it through the pickled model, and
# print the probability of class index 1 followed by a positive/negative label.
tokenizer = MLukeTokenizer.from_pretrained('studio-ousia/luke-japanese-base-lite')

# SECURITY: torch.load unpickles the file, and unpickling can execute
# arbitrary code -- only load checkpoints that come from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects fully pickled modules; if loading fails
# there, pass weights_only=False (confirm against the installed version).
# map_location='cpu' keeps the model on the same device as the tokenizer's
# output tensors, which are passed to the model below without any .to(...).
model = torch.load(
    'C:\\[modelのあるディレクトリ]\\My_luke_model_pn.pth',
    map_location='cpu',
)
# A pickled module restores whatever train/eval flag it was saved with; force
# eval mode so train-only layers (e.g. dropout) cannot perturb the prediction.
model.eval()

text = input()
encoded_dict = tokenizer.encode_plus(
    text,
    truncation=True,  # clip over-long input to the tokenizer's maximum length
    return_attention_mask=True,  # build the attention mask
    return_tensors='pt',  # return PyTorch tensors
)

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    pre = model(
        encoded_dict['input_ids'],
        token_type_ids=None,
        attention_mask=encoded_dict['attention_mask'],
    )

# pre.logits[0] is the logit vector for the single input text; softmax over
# it yields per-class probabilities.
num = torch.softmax(pre.logits[0], dim=0)

# Index 1 is reported as the positive class; above 0.5 means positive.
print(str(num[1]))
if num[1] > 0.5:
    print('ポジティブ')
else:
    print('ネガティブ')