Commit
·
70e394e
1
Parent(s):
2fcb9be
fixed missing attention mask code
Browse files
README.md
CHANGED
@@ -69,16 +69,13 @@ class SentimentModel():
|
|
69 |
def predict_sentiment(self, texts: List[str])-> List[str]:
|
70 |
texts = [self.clean_text(text) for text in texts]
|
71 |
# Add special tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
with torch.no_grad():
|
76 |
-
|
77 |
-
|
78 |
label_ids = torch.argmax(logits[0], axis=1)
|
79 |
-
|
80 |
-
labels = [self.model.config.id2label[label_id] for label_id in label_ids.tolist()]
|
81 |
-
return labels
|
82 |
|
83 |
def replace_numbers(self,text: str) -> str:
|
84 |
return text.replace("0"," null").replace("1"," eins").replace("2"," zwei").replace("3"," drei").replace("4"," vier").replace("5"," fünf").replace("6"," sechs").replace("7"," sieben").replace("8"," acht").replace("9"," neun")
|
|
|
69 |
def predict_sentiment(self, texts: List[str])-> List[str]:
|
70 |
texts = [self.clean_text(text) for text in texts]
|
71 |
# Add special tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
|
72 |
+
encoded = self.tokenizer.batch_encode_plus(texts,padding=True, add_special_tokens=True,truncation=True, return_tensors="pt")
|
73 |
+
encoded = encoded.to(self.device)
|
|
|
74 |
with torch.no_grad():
|
75 |
+
logits = self.model(**encoded)
|
76 |
+
|
77 |
label_ids = torch.argmax(logits[0], axis=1)
|
78 |
+
return [self.model.config.id2label[label_id.item()] for label_id in label_ids]
|
|
|
|
|
79 |
|
80 |
def replace_numbers(self,text: str) -> str:
|
81 |
return text.replace("0"," null").replace("1"," eins").replace("2"," zwei").replace("3"," drei").replace("4"," vier").replace("5"," fünf").replace("6"," sechs").replace("7"," sieben").replace("8"," acht").replace("9"," neun")
|