Spaces:
Runtime error
Runtime error
dragonSwing
commited on
Commit
•
07eef75
1
Parent(s):
dd9b3ed
Fix error_prob bug
Browse files- gec_model.py +5 -1
gec_model.py
CHANGED
@@ -89,6 +89,7 @@ class GecBERTModel(torch.nn.Module):
|
|
89 |
self.lowercase_tokens = lowercase_tokens
|
90 |
self.min_error_probability = min_error_probability
|
91 |
self.vocab = Vocabulary.from_files(vocab_path)
|
|
|
92 |
self.log = log
|
93 |
self.iterations = iterations
|
94 |
self.confidence = confidence
|
@@ -337,7 +338,10 @@ class GecBERTModel(torch.nn.Module):
|
|
337 |
for output, weight in zip(data, self.model_weights):
|
338 |
class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
|
339 |
all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
|
340 |
-
|
|
|
|
|
|
|
341 |
|
342 |
max_vals = torch.max(all_class_probs, dim=-1)
|
343 |
probs = max_vals[0].tolist()
|
|
|
89 |
self.lowercase_tokens = lowercase_tokens
|
90 |
self.min_error_probability = min_error_probability
|
91 |
self.vocab = Vocabulary.from_files(vocab_path)
|
92 |
+
self.incorr_index = self.vocab.get_token_index("INCORRECT", "d_tags")
|
93 |
self.log = log
|
94 |
self.iterations = iterations
|
95 |
self.confidence = confidence
|
|
|
338 |
for output, weight in zip(data, self.model_weights):
|
339 |
class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
|
340 |
all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
|
341 |
+
class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1)
|
342 |
+
error_probs_d = class_probabilities_d[:, :, self.incorr_index]
|
343 |
+
incorr_prob = torch.max(error_probs_d, dim=-1)[0]
|
344 |
+
error_probs += weight * incorr_prob / sum(self.model_weights)
|
345 |
|
346 |
max_vals = torch.max(all_class_probs, dim=-1)
|
347 |
probs = max_vals[0].tolist()
|