Commit 67a58db
Parent(s): aee0221

add handler

Files changed:
- __pycache__/handler.cpython-310.pyc +0 -0
- data/verb-form-vocab.txt +0 -0
- gector/__init__.py +22 -0
- gector/__pycache__/__init__.cpython-310.pyc +0 -0
- gector/__pycache__/configuration.cpython-310.pyc +0 -0
- gector/__pycache__/dataset.cpython-310.pyc +0 -0
- gector/__pycache__/modeling.cpython-310.pyc +0 -0
- gector/__pycache__/predict.cpython-310.pyc +0 -0
- gector/__pycache__/predict_verbose.cpython-310.pyc +0 -0
- gector/__pycache__/vocab.cpython-310.pyc +0 -0
- gector/configuration.py +38 -0
- gector/dataset.py +164 -0
- gector/modeling.py +200 -0
- gector/predict.py +232 -0
- gector/predict_verbose.py +83 -0
- gector/vocab.py +48 -0
- handler.py +45 -0
- requirements.txt +27 -0
__pycache__/handler.cpython-310.pyc ADDED (binary file, 2.17 kB)
data/verb-form-vocab.txt ADDED (diff too large to render)
gector/__init__.py ADDED
@@ -0,0 +1,22 @@
+from .modeling import GECToR
+from .configuration import GECToRConfig
+from .dataset import load_dataset, GECToRDataset
+from .predict import predict, load_verb_dict
+from .predict_verbose import predict_verbose
+from .vocab import (
+    build_vocab,
+    load_vocab_from_config,
+    load_vocab_from_official
+)
+__all__ = [
+    'GECToR',
+    'GECToRConfig',
+    'load_dataset',
+    'GECToRDataset',
+    'predict',
+    'load_verb_dict',
+    'predict_verbose',
+    'build_vocab',
+    'load_vocab_from_config',
+    'load_vocab_from_official'
+]
gector/__pycache__/__init__.cpython-310.pyc ADDED (binary file, 626 Bytes)

gector/__pycache__/configuration.cpython-310.pyc ADDED (binary file, 1.6 kB)

gector/__pycache__/dataset.cpython-310.pyc ADDED (binary file, 5.04 kB)

gector/__pycache__/modeling.cpython-310.pyc ADDED (binary file, 5.49 kB)

gector/__pycache__/predict.cpython-310.pyc ADDED (binary file, 5.35 kB)

gector/__pycache__/predict_verbose.cpython-310.pyc ADDED (binary file, 1.93 kB)

gector/__pycache__/vocab.cpython-310.pyc ADDED (binary file, 2.23 kB)
gector/configuration.py ADDED
@@ -0,0 +1,38 @@
+import os
+import json
+from transformers import PretrainedConfig
+class GECToRConfig(PretrainedConfig):
+    def __init__(
+        self,
+        model_id: str = 'bert-base-cased',
+        p_dropout: float=0,
+        label_pad_token: str='<PAD>',
+        label_oov_token: str='<OOV>',
+        d_pad_token: str='<PAD>',
+        keep_label: str='$KEEP',
+        correct_label: str='$CORRECT',
+        incorrect_label: str='$INCORRECT',
+        label_smoothing: float=0.0,
+        has_add_pooling_layer: bool=True,
+        initializer_range: float=0.02,
+        **kwards
+    ):
+        super().__init__(**kwards)
+        self.d_label2id = {
+            "$CORRECT": 0,
+            "$INCORRECT": 1,
+            "<PAD>": 2
+        }
+        self.d_id2label = {v: k for k, v in self.d_label2id.items()}
+        self.d_num_labels = len(self.d_label2id)
+        self.model_id = model_id
+        self.p_dropout = p_dropout
+        self.label_pad_token = label_pad_token
+        self.label_oov_token = label_oov_token
+        self.d_pad_token = d_pad_token
+        self.keep_label = keep_label
+        self.correct_label = correct_label
+        self.incorrect_label = incorrect_label
+        self.label_smoothing = label_smoothing
+        self.has_add_pooling_layer = has_add_pooling_layer
+        self.initializer_range = initializer_range
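For reference, a minimal sketch of instantiating this configuration on its own. The values shown are illustrative; anything not passed falls back to the defaults in the signature above, and the correction-label vocabulary (label2id / num_labels) is presumably supplied through PretrainedConfig's **kwargs or a saved config.json, which this commit does not show.

from gector import GECToRConfig

# Illustrative values only; defaults come from the signature above.
config = GECToRConfig(
    model_id='bert-base-cased',
    p_dropout=0.1,
    label_smoothing=0.1
)
print(config.d_label2id)  # {'$CORRECT': 0, '$INCORRECT': 1, '<PAD>': 2}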
gector/dataset.py ADDED
@@ -0,0 +1,164 @@
+from typing import List, Tuple
+from collections import Counter
+import torch
+from tqdm import tqdm
+import os
+from transformers import PreTrainedTokenizer
+
+class GECToRDataset:
+    def __init__(
+        self,
+        srcs: List[str],
+        d_labels: List[List[int]]=None,
+        labels: List[List[int]]=None,
+        word_masks: List[List[int]]=None,
+        tokenizer: PreTrainedTokenizer=None,
+        max_length:int=128
+    ):
+        self.tokenizer = tokenizer
+        self.srcs = srcs
+        self.d_labels = d_labels
+        self.labels = labels
+        self.word_masks = word_masks
+        self.max_length = max_length
+        self.label2id = None
+        self.d_label2id = None
+
+    def __len__(self):
+        return len(self.srcs)
+
+    def __getitem__(self, idx):
+        src = self.srcs[idx]
+        d_labels = self.d_labels[idx]
+        labels = self.labels[idx]
+        wmask = self.word_masks[idx]
+        encode = self.tokenizer(
+            src,
+            return_tensors='pt',
+            max_length=self.max_length,
+            padding='max_length',
+            truncation=True,
+            is_split_into_words=True
+        )
+        return {
+            'input_ids': encode['input_ids'].squeeze(),
+            'attention_mask': encode['attention_mask'].squeeze(),
+            'd_labels': torch.tensor(d_labels).squeeze(),
+            'labels': torch.tensor(labels).squeeze(),
+            'word_masks': torch.tensor(wmask).squeeze()
+        }
+
+    def append_vocab(self, label2id, d_label2id):
+        self.label2id = label2id
+        self.d_label2id = d_label2id
+        for i in range(len(self.labels)):
+            self.labels[i] = [self.label2id.get(l, self.label2id['<OOV>']) for l in self.labels[i]]
+            self.d_labels[i] = [self.d_label2id[l] for l in self.d_labels[i]]
+
+    def get_labels_freq(self, exluded_labels: List[str] = []):
+        assert(self.labels is not None and self.d_labels is not None)
+        flatten_labels = [ll for l in self.labels for ll in l if ll not in exluded_labels]
+        flatten_d_labels = [ll for l in self.d_labels for ll in l if ll not in exluded_labels]
+        return Counter(flatten_labels), Counter(flatten_d_labels)
+
+def align_labels_to_subwords(
+    srcs: List[str],
+    word_labels: List[List[str]],
+    tokenizer: PreTrainedTokenizer,
+    batch_size: int=100000,
+    max_length: int=128,
+    keep_label: str='$KEEP',
+    pad_token: str='<PAD>',
+    correct_label: str='$CORRECT',
+    incorrect_label: str='$INCORRECT'
+):
+    itr = list(range(0, len(srcs), batch_size))
+    subword_labels = []
+    subword_d_labels = []
+    word_masks = []
+    for i in tqdm(itr):
+        encode = tokenizer(
+            srcs[i:i+batch_size],
+            max_length=max_length,
+            return_tensors='pt',
+            padding='max_length',
+            truncation=True,
+            is_split_into_words=True
+        )
+        for i, wlabels in enumerate(word_labels[i:i+batch_size]):
+            d_labels = []
+            labels = []
+            wmask = []
+            word_ids = encode.word_ids(i)
+            previous_word_idx = None
+            for word_idx in word_ids:
+                if word_idx is None:
+                    labels.append(pad_token)
+                    d_labels.append(pad_token)
+                    wmask.append(0)
+                elif word_idx != previous_word_idx:
+                    l = wlabels[word_idx]
+                    labels.append(l)
+                    wmask.append(1)
+                    if l != keep_label:
+                        d_labels.append(incorrect_label)
+                    else:
+                        d_labels.append(correct_label)
+                else:
+                    labels.append(pad_token)
+                    d_labels.append(pad_token)
+                    wmask.append(0)
+                previous_word_idx = word_idx
+            subword_d_labels.append(d_labels)
+            subword_labels.append(labels)
+            word_masks.append(wmask)
+    return subword_d_labels, subword_labels, word_masks
+
+def load_gector_format(
+    input_file: str,
+    delimeter: str='SEPL|||SEPR',
+    additional_delimeter: str='SEPL__SEPR'
+):
+    srcs = []
+    word_level_labels = []  # the size will be (#sents, seq_length) if not get_interactive_tags,
+                            # (#iteration, #sents, seq_length) if get_interactive_tags
+    with open(input_file) as f:
+        for line in f:
+            src = [x.split(delimeter)[0] for x in line.split()]
+            labels = [x.split(delimeter)[1] for x in line.split()]
+            # Use only first tags. E.g. $REPLACE_meSEPL__SEPR$APPEND_too → $REPLACE_me
+            labels = [l.split(additional_delimeter)[0] for l in labels]
+            srcs.append(src)
+            word_level_labels.append(labels)
+    return srcs, word_level_labels
+
+def load_dataset(
+    input_file: str,
+    tokenizer: PreTrainedTokenizer,
+    delimeter: str='SEPL|||SEPR',
+    additional_delimeter: str='SEPL__SEPR',
+    batch_size: int=50000, # avoid too heavy computation in the tokenization
+    max_length: int=128
+):
+    srcs, word_level_labels = load_gector_format(
+        input_file,
+        delimeter=delimeter,
+        additional_delimeter=additional_delimeter
+    )
+    d_labels, labels, word_masks = align_labels_to_subwords(
+        srcs,
+        word_level_labels,
+        tokenizer=tokenizer,
+        batch_size=batch_size,
+        max_length=max_length
+    )
+    return GECToRDataset(
+        srcs=srcs,
+        d_labels=d_labels,
+        labels=labels,
+        word_masks=word_masks,
+        tokenizer=tokenizer,
+        max_length=max_length
+    )
+
+
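load_gector_format reads the official GECToR tagging format: one sentence per line, each token joined to its tag by the SEPL|||SEPR delimiter, with stacked tags on one token separated by SEPL__SEPR (only the first tag is kept). A minimal, hypothetical sketch of building a training dataset with these helpers; the file path is a placeholder, and registering the extra $START token with the tokenizer is an assumption (this commit only resizes the embedding matrix in modeling.py).

from transformers import AutoTokenizer
from gector import load_dataset, build_vocab

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
tokenizer.add_tokens(['$START'], special_tokens=True)  # assumption: how $START is registered is not shown in this commit

# 'train.gector.txt' is a placeholder path to data in the SEPL|||SEPR format.
train = load_dataset('train.gector.txt', tokenizer, max_length=128)
label2id, d_label2id = build_vocab(train, n_max_labels=5000)
train.append_vocab(label2id, d_label2id)  # replaces string tags with integer ids in place
print(len(train), train[0]['input_ids'].shape)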
gector/modeling.py ADDED
@@ -0,0 +1,200 @@
+from transformers import AutoModel, AutoTokenizer, AutoConfig, PreTrainedModel
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+from dataclasses import dataclass
+from .configuration import GECToRConfig
+from typing import List, Union, Optional, Tuple
+import os
+import json
+from huggingface_hub import snapshot_download, ModelCard
+
+@dataclass
+class GECToROutput:
+    loss: torch.Tensor = None
+    loss_d: torch.Tensor = None
+    loss_labels: torch.Tensor = None
+    logits_d: torch.Tensor = None
+    logits_labels: torch.Tensor = None
+    accuracy: torch.Tensor = None
+    accuracy_d: torch.Tensor = None
+
+@dataclass
+class GECToRPredictionOutput:
+    probability_labels: torch.Tensor = None
+    probability_d: torch.Tensor = None
+    pred_labels: List[List[str]] = None
+    pred_label_ids: torch.Tensor = None
+    max_error_probability: torch.Tensor = None
+
+class GECToR(PreTrainedModel):
+    config_class = GECToRConfig
+    def __init__(
+        self,
+        config: GECToRConfig
+    ):
+        super().__init__(config)
+        self.config = config
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.config.model_id
+        )
+        if self.config.has_add_pooling_layer:
+            self.bert = AutoModel.from_pretrained(
+                self.config.model_id,
+                add_pooling_layer=False
+            )
+        else:
+            self.bert = AutoModel.from_pretrained(
+                self.config.model_id
+            )
+        # +1 is for $START token
+        self.bert.resize_token_embeddings(self.bert.config.vocab_size + 1)
+        self.label_proj_layer = nn.Linear(
+            self.bert.config.hidden_size,
+            self.config.num_labels - 1
+        ) # -1 is for <PAD>
+        self.d_proj_layer = nn.Linear(
+            self.bert.config.hidden_size,
+            self.config.d_num_labels - 1
+        )
+        self.dropout = nn.Dropout(self.config.p_dropout)
+        self.loss_fn = CrossEntropyLoss(
+            label_smoothing=self.config.label_smoothing
+        )
+
+        self.post_init()
+        self.tune_bert(False)
+
+    def init_weight(self) -> None:
+        self._init_weights(self.label_proj_layer)
+        self._init_weights(self.d_proj_layer)
+
+    def _init_weights(self, module) -> None:
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(
+                mean=0.0,
+                std=self.config.initializer_range
+            )
+            if module.bias is not None:
+                module.bias.data.zero_()
+        return
+
+    def tune_bert(self, tune=True):
+        # If tune=False, only classifier layers will be tuned.
+        for param in self.bert.parameters():
+            param.requires_grad = tune
+        return
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        d_labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        word_masks: Optional[torch.Tensor] = None,
+    ) -> GECToROutput:
+        bert_logits = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        ).last_hidden_state
+        logits_d = self.d_proj_layer(bert_logits)
+        logits_labels = self.label_proj_layer(self.dropout(bert_logits))
+        loss_d, loss_labels, loss = None, None, None
+        accuracy, accuracy_d = None, None
+        if d_labels is not None and labels is not None:
+            pad_id = self.config.label2id[self.config.label_pad_token]
+            # -100 is the default ignore_idx of CrossEntropyLoss
+            labels[labels == pad_id] = -100
+            d_labels[labels == -100] = -100
+            loss_d = self.loss_fn(
+                logits_d.view(-1, self.config.d_num_labels - 1), # -1 for <PAD>
+                d_labels.view(-1)
+            )
+            loss_labels = self.loss_fn(
+                logits_labels.view(-1, self.config.num_labels - 1),
+                labels.view(-1)
+            )
+            loss = loss_d + loss_labels
+
+            pred_labels = torch.argmax(logits_labels, dim=-1)
+            accuracy = torch.sum(
+                (labels == pred_labels) * word_masks
+            ) / torch.sum(word_masks)
+            pred_d = torch.argmax(logits_d, dim=-1)
+            accuracy_d = torch.sum(
+                (d_labels == pred_d) * word_masks
+            ) / torch.sum(word_masks)
+
+        return GECToROutput(
+            loss=loss,
+            loss_d=loss_d,
+            loss_labels=loss_labels,
+            logits_d=logits_d,
+            logits_labels=logits_labels,
+            accuracy=accuracy,
+            accuracy_d=accuracy_d
+        )
+
+    def predict(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        word_masks: torch.Tensor,
+        keep_confidence: float=0,
+        min_error_prob: float=0
+    ):
+        with torch.no_grad():
+            outputs = self.forward(
+                input_ids,
+                attention_mask
+            )
+            probability_labels = F.softmax(outputs.logits_labels, dim=-1)
+            probability_d = F.softmax(outputs.logits_d, dim=-1)
+
+            # Get actual labels considering inference parameters.
+            keep_index = self.config.label2id[self.config.keep_label]
+            probability_labels[:, :, keep_index] += keep_confidence
+            incor_idx = self.config.d_label2id[self.config.incorrect_label]
+            probability_d = probability_d[:, :, incor_idx]
+            max_error_probability = torch.max(probability_d * word_masks, dim=-1)[0]
+            probability_labels[max_error_probability < min_error_prob, :, keep_index] \
+                = float('inf')
+            pred_label_ids = torch.argmax(probability_labels, dim=-1)
+
+            def convert_ids_to_labels(ids, id2label):
+                labels = []
+                for id in ids.tolist():
+                    labels.append(id2label[id])
+                return labels
+
+            pred_labels = []
+            for ids in pred_label_ids:
+                labels = convert_ids_to_labels(
+                    ids,
+                    self.config.id2label
+                )
+                pred_labels.append(labels)
+
+            return GECToRPredictionOutput(
+                probability_labels=probability_labels,
+                probability_d=probability_d,
+                pred_labels=pred_labels,
+                pred_label_ids=pred_label_ids,
+                max_error_probability=max_error_probability
+            )
gector/predict.py ADDED
@@ -0,0 +1,232 @@
+import torch
+import os
+from tqdm import tqdm
+from .modeling import GECToR
+from transformers import PreTrainedTokenizer
+from typing import List
+
+def load_verb_dict(verb_file: str):
+    path_to_dict = os.path.join(verb_file)
+    encode, decode = {}, {}
+    with open(path_to_dict, encoding="utf-8") as f:
+        for line in f:
+            words, tags = line.split(":")
+            word1, word2 = words.split("_")
+            tag1, tag2 = tags.split("_")
+            decode_key = f"{word1}_{tag1}_{tag2.strip()}"
+            if decode_key not in decode:
+                encode[words] = tags
+                decode[decode_key] = word2
+    return encode, decode
+
+def edit_src_by_tags(
+    srcs: List[List[str]],
+    pred_labels: List[List[str]],
+    encode: dict,
+    decode: dict
+) -> List[str]:
+    edited_srcs = []
+    for tokens, labels in zip(srcs, pred_labels):
+        edited_tokens = []
+        for t, l, in zip(tokens, labels):
+            n_token = process_token(t, l, encode, decode)
+            if n_token == None:
+                n_token = t
+            edited_tokens += n_token.split(' ')
+        if len(tokens) > len(labels):
+            omitted_tokens = tokens[len(labels):]
+            edited_tokens += omitted_tokens
+        temp_str = ' '.join(edited_tokens) \
+            .replace(' $MERGE_HYPHEN ', '-') \
+            .replace(' $MERGE_SPACE ', '') \
+            .replace(' $DELETE', '') \
+            .replace('$DELETE ', '')
+        edited_srcs.append(temp_str.split(' '))
+    return edited_srcs
+
+def process_token(
+    token: str,
+    label: str,
+    encode: dict,
+    decode: dict
+) -> str:
+    if '$APPEND_' in label:
+        return token + ' ' + label.replace('$APPEND_', '')
+    elif token == '$START':
+        # [unused1] token cannot be replaced with another token and cannot be deleted.
+        return token
+    elif label in ['<PAD>', '<OOV>', '$KEEP']:
+        return token
+    elif '$APPEND_' in label:
+        return token + ' ' + label.replace('$APPEND_', '')
+    elif '$TRANSFORM_' in label:
+        return g_transform_processer(token, label, encode, decode)
+    elif '$REPLACE_' in label:
+        return label.replace('$REPLACE_', '')
+    elif label == '$DELETE':
+        return label
+    elif '$MERGE_' in label:
+        return token + ' ' + label
+    else:
+        return token
+
+def g_transform_processer(
+    token: str,
+    label: str,
+    encode: dict,
+    decode: dict
+) -> str:
+    # Case related
+    if label == '$TRANSFORM_CASE_LOWER':
+        return token.lower()
+    elif label == '$TRANSFORM_CASE_UPPER':
+        return token.upper()
+    elif label == '$TRANSFORM_CASE_CAPITAL':
+        return token.capitalize()
+    elif label == '$TRANSFORM_CASE_CAPITAL_1':
+        if len(token) <= 1:
+            return token
+        return token[0] + token[1:].capitalize()
+    elif label == '$TRANSFORM_AGREEMENT_PLURAL':
+        return token + 's'
+    elif label == '$TRANSFORM_AGREEMENT_SINGULAR':
+        return token[:-1]
+    elif label == '$TRANSFORM_SPLIT_HYPHEN':
+        return ' '.join(token.split('-'))
+    else:
+        encoding_part = f"{token}_{label[len('$TRANSFORM_VERB_'):]}"
+        decoded_target_word = decode.get(encoding_part)
+        return decoded_target_word
+
+def get_word_masks_from_word_ids(
+    word_ids: List[List[int]],
+    n: int
+):
+    word_masks = []
+    for i in range(n):
+        previous_id = 0
+        mask = []
+        for _id in word_ids(i):
+            if _id is None:
+                mask.append(0)
+            elif previous_id != _id:
+                mask.append(1)
+            else:
+                mask.append(0)
+            previous_id = _id
+        word_masks.append(mask)
+    return word_masks
+
+def _predict(
+    model: GECToR,
+    tokenizer: PreTrainedTokenizer,
+    srcs: List[str],
+    keep_confidence: float=0,
+    min_error_prob: float=0,
+    batch_size: int=128
+):
+    itr = list(range(0, len(srcs), batch_size))
+    pred_labels = []
+    no_corrections = []
+    for i in tqdm(itr):
+        batch = tokenizer(
+            srcs[i:i+batch_size],
+            return_tensors='pt',
+            max_length=model.config.max_length,
+            padding='max_length',
+            truncation=True,
+            is_split_into_words=True
+        )
+        batch['word_masks'] = torch.tensor(
+            get_word_masks_from_word_ids(
+                batch.word_ids,
+                batch['input_ids'].size(0)
+            )
+        )
+        word_ids = batch.word_ids
+        if torch.cuda.is_available():
+            batch = {k:v.cuda() for k,v in batch.items()}
+        outputs = model.predict(
+            batch['input_ids'],
+            batch['attention_mask'],
+            batch['word_masks'],
+            keep_confidence,
+            min_error_prob
+        )
+        # Align subword-level label to word-level label
+        for i in range(len(outputs.pred_labels)):
+            no_correct = True
+            labels = []
+            previous_word_idx = None
+            for j, idx in enumerate(word_ids(i)):
+                if idx is None:
+                    continue
+                if idx != previous_word_idx:
+                    labels.append(outputs.pred_labels[i][j])
+                    if outputs.pred_label_ids[i][j] > 2:
+                        no_correct = False
+                previous_word_idx = idx
+            # print(no_correct, labels)
+            pred_labels.append(labels)
+            no_corrections.append(no_correct)
+    # print(pred_labels)
+    return pred_labels, no_corrections
+
+def predict(
+    model: GECToR,
+    tokenizer: PreTrainedTokenizer,
+    srcs: List[str],
+    encode: dict,
+    decode: dict,
+    keep_confidence: float=0,
+    min_error_prob: float=0,
+    batch_size: int=128,
+    n_iteration: int=5
+) -> List[str]:
+    srcs = [['$START'] + src.split(' ') for src in srcs]
+    final_edited_sents = ['-1'] * len(srcs)
+    to_be_processed = srcs
+    original_sent_idx = list(range(0, len(srcs)))
+    for itr in range(n_iteration):
+        print(f'Iteratoin {itr}. the number of to_be_processed: {len(to_be_processed)}')
+        pred_labels, no_corrections = _predict(
+            model,
+            tokenizer,
+            to_be_processed,
+            keep_confidence,
+            min_error_prob,
+            batch_size
+        )
+        current_srcs = []
+        current_pred_labels = []
+        current_orig_idx = []
+        for i, yes in enumerate(no_corrections):
+            if yes: # there's no corrections?
+                final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
+            else:
+                current_srcs.append(to_be_processed[i])
+                current_pred_labels.append(pred_labels[i])
+                current_orig_idx.append(original_sent_idx[i])
+        if current_srcs == []:
+            # Correcting for all sentences is completed.
+            break
+        # if itr > 2:
+        #     for l in current_pred_labels:
+        #         print(l)
+        edited_srcs = edit_src_by_tags(
+            current_srcs,
+            current_pred_labels,
+            encode,
+            decode
+        )
+        to_be_processed = edited_srcs
+        original_sent_idx = current_orig_idx
+
+    # print(f'=== Iteration {itr} ===')
+    # print('\n'.join(final_edited_sents))
+    # print(to_be_processed)
+    # print(have_corrections)
+    for i in range(len(to_be_processed)):
+        final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
+    assert('-1' not in final_edited_sents)
+    return final_edited_sents
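Taken together with modeling.py, the intended inference flow appears to be: load the model and tokenizer, load the verb-form dictionary, then call predict, which re-tags and re-edits the remaining sentences for up to n_iteration rounds. A minimal sketch, assuming a checkpoint path (placeholder) that holds both the GECToR weights and the tokenizer:

from transformers import AutoTokenizer
from gector import GECToR, predict, load_verb_dict

model_path = 'path/to/gector-checkpoint'  # placeholder
model = GECToR.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
encode, decode = load_verb_dict('data/verb-form-vocab.txt')

srcs = ['this sentences contain grammar error .']
corrected = predict(
    model, tokenizer, srcs,
    encode, decode,
    keep_confidence=0.0,  # added to the $KEEP probability; larger values make edits more conservative
    min_error_prob=0.0,   # sentences whose max error probability falls below this are left unchanged
    batch_size=128,
    n_iteration=5
)
print(corrected)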
gector/predict_verbose.py ADDED
@@ -0,0 +1,83 @@
+import torch
+import os
+from tqdm import tqdm
+from .modeling import GECToR
+from transformers import PreTrainedTokenizer
+from typing import List, Dict
+from .predict import (
+    edit_src_by_tags,
+    _predict
+)
+
+def predict_verbose(
+    model: GECToR,
+    tokenizer: PreTrainedTokenizer,
+    srcs: List[str],
+    encode: dict,
+    decode: dict,
+    keep_confidence: float=0,
+    min_error_prob: float=0,
+    batch_size: int=128,
+    n_iteration: int=5
+) -> List[str]:
+    srcs = [['$START'] + src.split(' ') for src in srcs]
+    final_edited_sents = ['-1'] * len(srcs)
+    to_be_processed = srcs
+    original_sent_idx = list(range(0, len(srcs)))
+    iteration_log: List[List[Dict]] = [] # [send_id][iteration_id]['src' or 'tags']
+    iteration_log = []
+    # Initialize iteration logs.
+    for i, src in enumerate(srcs):
+        iteration_log.append([{
+            'src': src,
+            'tag': None
+        }])
+    for itr in range(n_iteration):
+        print(f'Iteratoin {itr}. the number of to_be_processed: {len(to_be_processed)}')
+        pred_labels, no_corrections = _predict(
+            model,
+            tokenizer,
+            to_be_processed,
+            keep_confidence,
+            min_error_prob,
+            batch_size
+        )
+        current_srcs = []
+        current_pred_labels = []
+        current_orig_idx = []
+        for i, yes in enumerate(no_corrections):
+            if yes: # there's no corrections?
+                final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
+            else:
+                current_srcs.append(to_be_processed[i])
+                current_pred_labels.append(pred_labels[i])
+                current_orig_idx.append(original_sent_idx[i])
+        if current_srcs == []:
+            # Correcting for all sentences is completed.
+            break
+        edited_srcs = edit_src_by_tags(
+            current_srcs,
+            current_pred_labels,
+            encode,
+            decode
+        )
+        # Register the information during iteration.
+        # edited_src will be the src of the next iteration.
+        for i, orig_id in enumerate(current_orig_idx):
+            iteration_log[orig_id][itr]['tag'] = current_pred_labels[i]
+            iteration_log[orig_id].append({
+                'src': edited_srcs[i],
+                'tag': None
+            })
+
+        to_be_processed = edited_srcs
+        original_sent_idx = current_orig_idx
+
+    # print(f'=== Iteration {itr} ===')
+    # print('\n'.join(final_edited_sents))
+    # print(to_be_processed)
+    # print(have_corrections)
+    for i in range(len(to_be_processed)):
+        final_edited_sents[original_sent_idx[i]] = ' '.join(to_be_processed[i]).replace('$START ', '')
+    assert('-1' not in final_edited_sents)
+    return final_edited_sents, iteration_log
gector/vocab.py ADDED
@@ -0,0 +1,48 @@
+from .configuration import GECToRConfig
+from .dataset import GECToRDataset
+import os
+
+def build_vocab(
+    train_dataset: GECToRDataset,
+    n_max_labels: int=5000,
+    n_max_d_labels: int=2
+):
+    label2id = {'<OOV>':0, '$KEEP':1}
+    d_label2id = {'$CORRECT':0, '$INCORRECT':1, '<PAD>':2}
+    freq_labels, _ = train_dataset.get_labels_freq(
+        exluded_labels=['<PAD>'] + list(label2id.keys())
+    )
+
+    def get_high_freq(freq: dict, n_max: int):
+        descending_freq = sorted(
+            freq.items(), key=lambda x:x[1], reverse=True
+        )
+        high_freq = [x[0] for x in descending_freq][:n_max]
+        if len(high_freq) < n_max:
+            print(f'Warning: the size of the vocablary: {len(high_freq)} is less than n_max: {n_max}.')
+        return high_freq
+
+    high_freq_labels = get_high_freq(freq_labels, n_max_labels-2)
+    for i, x in enumerate(high_freq_labels):
+        label2id[x] = i + 2
+    label2id['<PAD>'] = len(label2id)
+    return label2id, d_label2id
+
+def load_vocab_from_config(config_file: str):
+    config = GECToRConfig.from_pretrained(config_file, not_dir=True)
+    return config.label2id, config.d_label2id
+
+def load_vocab_from_official(dir):
+    vocab_path = os.path.join(dir, 'labels.txt')
+    vocab = open(vocab_path).read().replace('@@PADDING@@', '').replace('@@UNKNOWN@@', '').rstrip().split('\n')
+    # vocab_d = open(dir + 'd_tags.txt').read().rstrip().replace('@@PADDING@@', '<PAD>').replace('@@UNKNOWN@@', '<OOV>').split('\n')
+    label2id = {'<OOV>':0, '$KEEP':1}
+    d_label2id = {'$CORRECT':0, '$INCORRECT':1, '<PAD>':2}
+    idx = len(label2id)
+    for v in vocab:
+        if v not in label2id:
+            label2id[v] = idx
+            idx += 1
+    label2id['<PAD>'] = idx
+    return label2id, d_label2id
+
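For completeness, a small sketch of converting an official GECToR vocabulary with load_vocab_from_official. The directory name is a placeholder; it only needs to contain a labels.txt file in the original GECToR format.

from gector import load_vocab_from_official

label2id, d_label2id = load_vocab_from_official('official_vocab/')  # placeholder directory
id2label = {v: k for k, v in label2id.items()}
print(len(label2id), label2id['$KEEP'], label2id['<PAD>'])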
handler.py ADDED
@@ -0,0 +1,45 @@
+from typing import Dict, List, Any
+from transformers import AutoTokenizer
+from gector import GECToR, predict, load_verb_dict
+
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        self.model = GECToR.from_pretrained(path)
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.encode, self.decode = load_verb_dict("data/verb-form-vocab.txt")
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        Process the input data and return the predicted results.
+
+        Args:
+            data (Dict[str, Any]): The input data dictionary containing the following keys:
+                - "inputs" (List[str]): A list of input strings to be processed.
+                - "n_iterations" (int, optional): The number of iterations for prediction. Defaults to 5.
+                - "batch_size" (int, optional): The batch size for prediction. Defaults to 2.
+                - "keep_confidence" (float, optional): The confidence threshold for keeping predictions. Defaults to 0.0.
+                - "min_error_prob" (float, optional): The minimum error probability for keeping predictions. Defaults to 0.0.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries containing the predicted results for each input string.
+        """
+        srcs = data["inputs"]
+
+        # Extract optional parameters from data, with defaults
+        n_iterations = data.get("n_iterations", 5)
+        batch_size = data.get("batch_size", 2)
+        keep_confidence = data.get("keep_confidence", 0.0)
+        min_error_prob = data.get("min_error_prob", 0.0)
+
+        return predict(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            srcs=srcs,
+            encode=self.encode,
+            decode=self.decode,
+            keep_confidence=keep_confidence,
+            min_error_prob=min_error_prob,
+            n_iteration=n_iterations,
+            batch_size=batch_size,
+        )
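A hedged sketch of exercising this handler locally; it assumes the handler is constructed at the repository root so that data/verb-form-vocab.txt resolves, mirroring the layout added in this commit. Note that predict returns a plain list of corrected strings, so the docstring's List[Dict[str, Any]] return annotation is broader than what is actually produced.

from handler import EndpointHandler

# Assumes the current directory holds the model weights and data/verb-form-vocab.txt.
handler = EndpointHandler(path='.')
result = handler({
    'inputs': ['she go to school every day .'],
    'n_iterations': 5,
    'batch_size': 2,
    'keep_confidence': 0.0,
    'min_error_prob': 0.0,
})
print(result)  # list of corrected sentences, one per input

When deployed behind a Hugging Face Inference Endpoint, the same dictionary would arrive as the JSON request body.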
requirements.txt ADDED
@@ -0,0 +1,27 @@
+accelerate==0.27.0
+certifi==2024.2.2
+charset-normalizer==3.3.2
+filelock==3.13.1
+fsspec==2024.2.0
+huggingface-hub==0.20.3
+idna==3.6
+Jinja2==3.1.3
+Levenshtein==0.24.0
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.2.1
+numpy==1.26.4
+packaging==23.2
+psutil==5.9.8
+PyYAML==6.0.1
+rapidfuzz==3.6.1
+regex==2023.12.25
+requests==2.31.0
+safetensors==0.4.2
+sympy==1.12
+tokenizers==0.15.1
+torch==2.2.0
+tqdm==4.66.2
+transformers==4.37.2
+typing_extensions==4.9.0
+urllib3==2.2.0