bkhmsi committed on
Commit
ebc546a
·
1 Parent(s): cd87bdb

pdd working now

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app.py +19 -8
  3. model_partial.py +34 -16
  4. partial_dd_metrics.py +329 -0
  5. predict.py +75 -4
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  *.pyc
2
  *.pt
3
  *.vec
 
4
  .DS_Store
 
1
  *.pyc
2
  *.pt
3
  *.vec
4
+ *.pem
5
  .DS_Store
app.py CHANGED
@@ -3,6 +3,7 @@ import yaml
3
  import gdown
4
  import gradio as gr
5
  from predict import PredictTri
 
6
 
7
  output_path = "tashkeela-d2.pt"
8
  if not os.path.exists(output_path):
@@ -20,18 +21,20 @@ with open("config.yaml", 'r', encoding="utf-8") as file:
20
  config["train"]["max-sent-len"] = config["predictor"]["window"]
21
  config["train"]["max-token-count"] = config["predictor"]["window"] * 3
22
 
23
- def diacritze(text):
24
- print(text)
25
  predictor = PredictTri(config, text)
26
- diacritized_lines = predictor.predict_majority_vote()
27
- return '\n'.join(diacritized_lines)
28
 
29
  with gr.Blocks() as demo:
30
  gr.Markdown(
31
  """
32
- # Partial Diacritization
33
- TODO: put paper links here
34
  """)
 
 
 
35
  input_txt = gr.Textbox(
36
  placeholder="اكتب هنا",
37
  lines=5,
@@ -50,7 +53,15 @@ with gr.Blocks() as demo:
50
  )
51
 
52
  btn = gr.Button(value="Shakkel")
53
- btn.click(diacritze, inputs=input_txt, outputs=output_txt)
54
 
55
  if __name__ == "__main__":
56
- demo.launch()
 
 
 
 
 
 
 
 
 
3
  import gdown
4
  import gradio as gr
5
  from predict import PredictTri
6
+ from gradio import blocks
7
 
8
  output_path = "tashkeela-d2.pt"
9
  if not os.path.exists(output_path):
 
21
  config["train"]["max-sent-len"] = config["predictor"]["window"]
22
  config["train"]["max-token-count"] = config["predictor"]["window"] * 3
23
 
24
def diacritze(text, do_partial):
    """Diacritize ``text`` and return the result as a single string.

    Builds a fresh ``PredictTri`` from the module-level ``config`` for every
    call and forwards ``do_partial`` to its ``predict_partial`` method.
    """
    return PredictTri(config, text).predict_partial(do_partial=do_partial)
28
 
29
  with gr.Blocks() as demo:
30
  gr.Markdown(
31
  """
32
+ # Partial Diacritization: A Context-Contrastive Inference Approach
33
+ ## Authors: Muhammad ElNokrashy, Badr AlKhamissi
34
  """)
35
+
36
+ check_box = gr.Checkbox(label="Partial", info="Apply Partial Diacritics or Full Diacritics")
37
+
38
  input_txt = gr.Textbox(
39
  placeholder="اكتب هنا",
40
  lines=5,
 
53
  )
54
 
55
  btn = gr.Button(value="Shakkel")
56
+ btn.click(diacritze, inputs=[input_txt, check_box], outputs=[output_txt])
57
 
58
  if __name__ == "__main__":
59
+ demo.queue().launch(
60
+ # share=False,
61
+ # debug=False,
62
+ # server_port=7860,
63
+ # server_name="0.0.0.0",
64
+ # ssl_verify=False,
65
+ # ssl_certfile="cert.pem",
66
+ # ssl_keyfile="key.pem"
67
+ )
model_partial.py CHANGED
@@ -5,10 +5,11 @@ import numpy as np
5
 
6
  import torch as T
7
  from torch import nn
8
- from torch import functional as F
9
  from diac_utils import flat_2_3head
10
 
11
  from model_dd import DiacritizerD2
 
12
 
13
  class Readout(nn.Module):
14
  def __init__(
@@ -56,24 +57,27 @@ class PartialDiacOutput(NamedTuple):
56
  preds_hard: T.Tensor
57
  preds_ctxt_logit: T.Tensor
58
  preds_base_logit: T.Tensor
59
-
60
 
61
  class PartialDD(nn.Module):
62
  def __init__(
63
  self,
64
  config: dict,
65
- # feature_size: int,
66
- # confidence_threshold: float,
67
- d2=False
68
  ):
69
  super().__init__()
70
  self._built = False
71
  self.no_diac_id = 0
72
  self._dummy = nn.Parameter(T.ones(1, 1))
73
-
 
 
74
  self.config = config
 
75
  self.sentence_diac = DiacritizerD2(self.config)
76
-
 
 
 
77
  self.eval()
78
 
79
  @property
@@ -114,6 +118,7 @@ class PartialDD(nn.Module):
114
 
115
  return toke_ids, char_ids, diac_ids, subword_lengths
116
 
 
117
  def word_diac(
118
  self,
119
  toke_ids: T.Tensor,
@@ -169,6 +174,7 @@ class PartialDD(nn.Module):
169
  z = z.reshape(Nb, Tw, Tc, -1)
170
  return z
171
 
 
172
  def forward(
173
  self,
174
  word_ids: T.Tensor,
@@ -178,8 +184,9 @@ class PartialDD(nn.Module):
178
  # padding_mask: T.BoolTensor,
179
  *,
180
  eval_only: str = None,
181
- subword_lengths: T.Tensor = None,
182
- return_extra: bool = False
 
183
  ):
184
  # assert self._built and not self.training
185
  assert not self.training
@@ -195,6 +202,7 @@ class PartialDD(nn.Module):
195
  word_ids,
196
  char_ids,
197
  _labels,
 
198
  )
199
  out_shape = y_ctxt.shape[:-1]
200
  else:
@@ -219,6 +227,7 @@ class PartialDD(nn.Module):
219
  if eval_only == 'base':
220
  return y_base.argmax(-1)
221
 
 
222
  ypred_ctxt = y_ctxt.argmax(-1)
223
  ypred_base = y_base.argmax(-1)
224
  #^ ypred: [b tw tc _]
@@ -226,7 +235,9 @@ class PartialDD(nn.Module):
226
  # Maybe for eval
227
  # ypred_ctxt[~((ypred_base == ground_truth) & (~padding_mask))] = self.no_diac_id
228
  # return ypred_ctxt
229
- ypred_ctxt[(padding_mask) | (ypred_base == ypred_ctxt)] = self.no_diac_id
 
 
230
  if not return_extra:
231
  return ypred_ctxt
232
  else:
@@ -250,6 +261,7 @@ class PartialDD(nn.Module):
250
  dataloader,
251
  return_extra=False,
252
  eval_only: str = None,
 
253
  ):
254
  training = self.training
255
  self.eval()
@@ -261,10 +273,11 @@ class PartialDD(nn.Module):
261
  'diacs': [],
262
  'y_ctxt': [],
263
  'y_base': [],
 
264
  }
265
  print("> Predicting...")
266
  # breakpoint()
267
- for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
268
  # if i_batch > 10:
269
  # break
270
  #^ inputs: [toke_ids, char_ids, diac_ids]
@@ -282,15 +295,19 @@ class PartialDD(nn.Module):
282
  subword_lengths=subword_lengths,
283
  return_extra=return_extra,
284
  eval_only=eval_only,
 
285
  )
286
 
287
  # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
288
  if return_extra:
289
  assert isinstance(output, PartialDiacOutput)
290
  marks = output.preds_hard
 
 
291
  preds['diacs'].extend(list(marks.detach().cpu().numpy()))
292
  preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
293
  preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
 
294
  else:
295
  assert isinstance(output, T.Tensor)
296
  marks = output
@@ -312,9 +329,10 @@ class PartialDD(nn.Module):
312
  np.array(preds["shadda"]),
313
  ),
314
  'other': ( # Would be empty when !return_extra
315
- preds['y_ctxt'],
316
- preds['y_base'],
317
- preds['diacs'],
 
318
  )
319
  }
320
 
@@ -327,7 +345,7 @@ class PartialDD(nn.Module):
327
  for inputs, _ in tqdm(dataloader, total=len(dataloader)):
328
  inputs[0] = inputs[0].to(self.device)
329
  inputs[1] = inputs[1].to(self.device)
330
- output = self(*inputs, eval_only='ctxt')
331
 
332
  # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
333
  marks = output
@@ -344,4 +362,4 @@ class PartialDD(nn.Module):
344
  np.array(preds['haraka']),
345
  np.array(preds["tanween"]),
346
  np.array(preds["shadda"]),
347
- )
 
5
 
6
  import torch as T
7
  from torch import nn
8
+ from torch.nn import functional as F
9
  from diac_utils import flat_2_3head
10
 
11
  from model_dd import DiacritizerD2
12
+ from model_dd import DatasetUtils
13
 
14
  class Readout(nn.Module):
15
  def __init__(
 
57
  preds_hard: T.Tensor
58
  preds_ctxt_logit: T.Tensor
59
  preds_base_logit: T.Tensor
 
60
 
61
  class PartialDD(nn.Module):
62
  def __init__(
63
  self,
64
  config: dict,
65
+ **kwargs
 
 
66
  ):
67
  super().__init__()
68
  self._built = False
69
  self.no_diac_id = 0
70
  self._dummy = nn.Parameter(T.ones(1, 1))
71
+ # with open('./configs/dd/config_d2.yaml', 'r', encoding='utf-8') as fin:
72
+ # self.config_d2 = yaml.safe_load(fin)
73
+ # self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
74
  self.config = config
75
+ self._use_d2 = True
76
  self.sentence_diac = DiacritizerD2(self.config)
77
+
78
+ # self.sentence_diac.to(self.device)
79
+ # self.build()
80
+ # self.word_diac = WordDD_LSTM(feature_size, num_classes=13, return_logits=False)
81
  self.eval()
82
 
83
  @property
 
118
 
119
  return toke_ids, char_ids, diac_ids, subword_lengths
120
 
121
+ T.jit.export
122
  def word_diac(
123
  self,
124
  toke_ids: T.Tensor,
 
174
  z = z.reshape(Nb, Tw, Tc, -1)
175
  return z
176
 
177
+ T.jit.ignore
178
  def forward(
179
  self,
180
  word_ids: T.Tensor,
 
184
  # padding_mask: T.BoolTensor,
185
  *,
186
  eval_only: str = None,
187
+ subword_lengths: T.Tensor,
188
+ return_extra: bool = False,
189
+ do_partial: bool = False,
190
  ):
191
  # assert self._built and not self.training
192
  assert not self.training
 
202
  word_ids,
203
  char_ids,
204
  _labels,
205
+ subword_lengths=subword_lengths,
206
  )
207
  out_shape = y_ctxt.shape[:-1]
208
  else:
 
227
  if eval_only == 'base':
228
  return y_base.argmax(-1)
229
 
230
+ #! TODO: Return the logits.
231
  ypred_ctxt = y_ctxt.argmax(-1)
232
  ypred_base = y_base.argmax(-1)
233
  #^ ypred: [b tw tc _]
 
235
  # Maybe for eval
236
  # ypred_ctxt[~((ypred_base == ground_truth) & (~padding_mask))] = self.no_diac_id
237
  # return ypred_ctxt
238
+ if do_partial:
239
+ ypred_ctxt[(padding_mask) | (ypred_base == ypred_ctxt)] = self.no_diac_id
240
+
241
  if not return_extra:
242
  return ypred_ctxt
243
  else:
 
261
  dataloader,
262
  return_extra=False,
263
  eval_only: str = None,
264
+ do_partial=True,
265
  ):
266
  training = self.training
267
  self.eval()
 
273
  'diacs': [],
274
  'y_ctxt': [],
275
  'y_base': [],
276
+ 'subword_lengths': [],
277
  }
278
  print("> Predicting...")
279
  # breakpoint()
280
+ for i_batch, (inputs, _) in enumerate(tqdm(dataloader)):
281
  # if i_batch > 10:
282
  # break
283
  #^ inputs: [toke_ids, char_ids, diac_ids]
 
295
  subword_lengths=subword_lengths,
296
  return_extra=return_extra,
297
  eval_only=eval_only,
298
+ do_partial=do_partial,
299
  )
300
 
301
  # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
302
  if return_extra:
303
  assert isinstance(output, PartialDiacOutput)
304
  marks = output.preds_hard
305
+ if eval_only == 'recalibrated':
306
+ marks = (output.preds_ctxt_logit + output.preds_base_logit).argmax(-1)
307
  preds['diacs'].extend(list(marks.detach().cpu().numpy()))
308
  preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
309
  preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
310
+ preds['subword_lengths'].extend(list(subword_lengths.detach().cpu().numpy()))
311
  else:
312
  assert isinstance(output, T.Tensor)
313
  marks = output
 
329
  np.array(preds["shadda"]),
330
  ),
331
  'other': ( # Would be empty when !return_extra
332
+ np.array(preds['y_ctxt']),
333
+ np.array(preds['y_base']),
334
+ np.array(preds['diacs']),
335
+ np.array(preds['subword_lengths']),
336
  )
337
  }
338
 
 
345
  for inputs, _ in tqdm(dataloader, total=len(dataloader)):
346
  inputs[0] = inputs[0].to(self.device)
347
  inputs[1] = inputs[1].to(self.device)
348
+ output = self(*inputs)
349
 
350
  # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
351
  marks = output
 
362
  np.array(preds['haraka']),
363
  np.array(preds["tanween"]),
364
  np.array(preds["shadda"]),
365
+ )
partial_dd_metrics.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple
2
+ from argparse import ArgumentParser
3
+
4
+ from tqdm import tqdm
5
+ import logging
6
+
7
+ import numpy as np
8
+ import torch as T
9
+ from torch.nn import functional as F
10
+
11
+ import diac_utils as du
12
+
13
+ _x = [
14
+ 'a'
15
+ ]
16
+
17
+ # logging.setLevel(logging.INFO)
18
+ logger = logging.getLogger(__file__)
19
+ logger.setLevel(logging.INFO)
20
+
21
def logln(*texts: str):
    """Print all given values on one line to stdout (plain ``print``)."""
    # logger.info(' '.join(texts))
    print(*texts)
24
+
25
+ # Relative improvement:
26
+ # T.mean((pred_c.argmax('c') == gt) - (pred_m.argmax('c') == gt))
27
+ # Coverage Confidence:
28
+ # pred_c.argmax('c')[pred_c.argmax('c') != pred_m.argmax('c')].mean()
29
+
30
class PartialDiacMetrics(NamedTuple):
    """Summary metrics comparing a contextual ("ctxt") and a base diacritizer.

    "Selected" characters are those flagged by the selection mask; all
    quantities are computed over non-padding characters only.
    """

    # Net (ctxt correct - base correct) on selected chars, over all chars.
    diff_total: float
    # Selected chars the base system already had right, over all chars.
    worse_total: float
    # Same net difference as diff_total, but over selected chars only.
    diff_relative: float
    # Error rate of the ctxt predictions over all chars.
    der_total: float
    # Fraction of chars that are selected.
    selectivity: float
    # Error rate of the base predictions on the unselected chars.
    hidden_der: float
    # Error rate of the ctxt predictions on the selected chars.
    partial_der: float
    # selectivity * partial_der + (1 - selectivity) * hidden_der.
    reader_error: float
39
+
40
def load_data(path: str):
    """Load the content stored at ``path``.

    Paths ending in ``.txt`` are read as UTF-8 text and returned as a list of
    lines (line endings kept). Every other path is passed to ``torch.load``.
    """
    if not path.endswith('.txt'):
        # NOTE(security): torch.load unpickles the file; only use on trusted dumps.
        return T.load(path)
    with open(path, 'r', encoding='utf-8') as handle:
        return list(handle)
46
+
47
def parse_data(
    data,
    logits: bool = False,
    side=None,
):
    """Extract diacritic predictions (and ground truth, if present) from ``data``.

    Args:
        data: One of
            * a dict with a ``'line_data'`` (or ``'line_data_fix'``) entry,
            * a list of text lines, or
            * an already-built tensor / ndarray of ids.
        logits: When true, read ``diac_logits_{side}`` from ``data['line_data']``
            and also return the logits.
        side: Suffix of the logits key (the CLI passes ``'base'`` / ``'ctxt'``).
            Only used for dict inputs; ``None`` selects ``'diac_pred'`` instead.

    Returns:
        ``(diac_pred, diac_gt, diac_logits)`` when ``logits`` is true, otherwise
        ``(diac_pred, diac_gt)`` where ``diac_gt`` is ``None`` for list and
        array inputs.

    Raises:
        NotImplementedError: if ``data`` is of an unsupported type.
    """
    if logits:
        line_data = data['line_data']
        # as_tensor avoids the copy (and the UserWarning) T.tensor emits when
        # the stored logits are already a tensor.
        diac_logits = T.as_tensor(line_data[f'diac_logits_{side}'])
        diac_pred: T.Tensor = diac_logits.argmax(dim=-1)
        diac_gt: T.Tensor = line_data['diac_gt']
        return diac_pred, diac_gt, diac_logits

    if isinstance(data, dict):
        # Prefer the corrected dump; look up 'line_data' only when it is
        # actually needed so dicts holding just 'line_data_fix' still work.
        if 'line_data_fix' in data:
            line_data = data['line_data_fix']
        else:
            line_data = data['line_data']
        if side is None:
            diac_pred = line_data['diac_pred']
        else:
            diac_pred = line_data[f'diac_logits_{side}'].argmax(axis=-1)
        return diac_pred, line_data['diac_gt']

    if isinstance(data, list):
        data_indices = [
            du.diac_ids_of_line(du.strip_tatweel(du.normalize_spaces(line)))
            for line in data
        ]
        # default=0 keeps an empty list of lines from raising in max().
        max_len = max(map(len, data_indices), default=0)
        out = np.full((len(data), max_len), fill_value=du.DIAC_PAD_IDX)
        for i_line, line_indices in enumerate(data_indices):
            out[i_line][:len(line_indices)] = line_indices
        return out, None

    if isinstance(data, (T.Tensor, np.ndarray)):
        return data, None

    raise NotImplementedError(
        f"parse_data does not support input of type {type(data).__name__}"
    )
82
+
83
def make_mask_hard(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
):
    """Return the element-wise ``pred_c != pred_m`` mask (True where they disagree)."""
    return pred_c != pred_m
89
+
90
def make_mask_logits(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    threshold: float = 0.1,
    version: str = '2',
) -> T.BoolTensor:
    """Select positions to mark by comparing ctxt and base class probabilities.

    Both inputs are softmax-normalised over the last axis first. The returned
    mask has the shape of the inputs without that last (class) axis.

    Args:
        pred_c: Logits of the contextual model, classes on the last axis.
        pred_m: Logits of the base model, same shape as ``pred_c``.
        threshold: Probability margin used by every version except ``'hard'``.
        version: Selection rule:
            * ``'hard'``    - argmax classes differ.
            * ``'0'``       - ctxt confidence > base confidence, and base
              confidence > threshold.
            * ``'1'``       - (ctxt confidence - base confidence) > threshold.
            * ``'1.1'``     - absolute value of that difference > threshold.
            * ``'2'``       - (p_ctxt - p_base) at the ctxt argmax > threshold.
            * ``'2.1'``     - same difference, taken at the base argmax.
            * ``'2.abs'`` / ``'2.1.abs'`` - absolute-value variants of the two above.
            * ``'3'``       - largest per-class (p_ctxt - p_base) > threshold.
            * ``'4'``       - ``'hard'`` AND ``'2'``.

    Raises:
        ValueError: if ``version`` is not one of the rules above.
    """
    logger.warning(f"{version=}, {threshold=}")
    prob_c = T.softmax(T.as_tensor(pred_c), dim=-1)
    prob_m = T.softmax(T.as_tensor(pred_m), dim=-1)
    delta = prob_c - prob_m
    top_c = prob_c.argmax(dim=-1, keepdim=True)
    top_m = prob_m.argmax(dim=-1, keepdim=True)

    def delta_at(index: T.Tensor) -> T.Tensor:
        # Drop only the gathered class axis. A bare squeeze() would also drop
        # a batch (or sequence) axis of size 1 and break later mask indexing.
        return T.gather(delta, dim=-1, index=index).squeeze(-1)

    disagree = top_c.squeeze(-1) != top_m.squeeze(-1)

    if version == 'hard':
        return disagree
    if version in ('0', '1', '1.1'):
        conf_c = prob_c.max(dim=-1).values
        conf_m = prob_m.max(dim=-1).values
        if version == '0':
            return (conf_c > conf_m) & (conf_m > threshold)
        if version == '1':
            return (conf_c - conf_m) > threshold
        return (conf_c - conf_m).abs() > threshold
    if version == '2':
        return delta_at(top_c) > threshold
    if version == '2.1':
        return delta_at(top_m) > threshold
    if version == '2.abs':
        return delta_at(top_c).abs() > threshold
    if version == '2.1.abs':
        return delta_at(top_m).abs() > threshold
    if version == '3':
        return delta.max(dim=-1).values > threshold
    if version == '4':
        return disagree & (delta_at(top_c) > threshold)

    raise ValueError(f"Unknown selection version: {version!r}")
135
+
136
def analysis_summary(
    pred_c : T.LongTensor,
    pred_m : T.LongTensor,
    labels : T.LongTensor,
    padding_mask: T.BoolTensor,
    *,
    selection : T.Tensor = None,
    random: bool = False,
    logits: tuple = None
):
    """Compute partial-diacritization metrics for a given selection mask.

    Args:
        pred_c: Class ids predicted by the contextual model.
        pred_m: Class ids predicted by the base model.
        labels: Ground-truth class ids.
        padding_mask: True at padded positions; those are excluded everywhere.
        selection: Boolean mask of characters chosen to be marked. Required
            despite the ``None`` default.
        random: If true, replace ``selection`` by a Bernoulli sample with the
            same mean rate (a random baseline at equal selectivity).
        logits: Optional ``(ctxt_logits, base_logits)``. When given, ``pred_c``
            is replaced by the argmax of the summed softmax probabilities.

    Returns:
        PartialDiacMetrics with every field expressed in percent, rounded to
        two decimals.

    Raises:
        ValueError: if ``selection`` is None.
        AssertionError: if there are no non-padding characters or nothing is
            selected.
    """
    if selection is None:
        # Previously this fell through to T.tensor(None) and failed with an
        # opaque dtype-inference error.
        raise ValueError("analysis_summary requires a `selection` mask")

    #^ pred_c: [b tw tc | ClassId]
    #^ pred_m: [b tw tc | ClassId]
    #^ labels: [b tw tc | ClassId]
    # as_tensor: no copy / no UserWarning when a tensor is passed in.
    # Cast to bool so `~` is a logical (not bitwise-integer) negation.
    padding_mask = T.as_tensor(padding_mask).to(T.bool)
    nonpad_mask = ~padding_mask
    num_chars = nonpad_mask.sum()

    if logits is not None:
        logits_c = T.as_tensor(logits[0])
        logits_m = T.as_tensor(logits[1])
        pred_c = (T.softmax(logits_c, dim=-1) + T.softmax(logits_m, dim=-1)).argmax(-1)

    pred_c = T.as_tensor(pred_c)[nonpad_mask]
    pred_m = T.as_tensor(pred_m)[nonpad_mask]
    labels = T.as_tensor(labels)[nonpad_mask]
    #^ : [(b * tw * tc) | ClassId]

    ctxt_match = (pred_c == labels).float()
    base_match = (pred_m == labels).float()

    selection = T.as_tensor(selection)[nonpad_mask]
    if random:
        # Same expected selectivity, but positions drawn independently.
        selection = pred_c.new_empty(pred_c.shape).bernoulli_(p=selection.float().mean()).to(bool)
    unselected = ~selection

    assert num_chars > 0, "no non-padding characters to evaluate"
    assert selection.sum() > 0, "selection mask marks no characters"

    num_selected = selection.sum()
    # NOTE: if every character is selected, unselected.sum() is 0 and the
    # base accuracy (hence hidden_der / reader_error) becomes NaN.
    base_accuracy = base_match[unselected].sum() / unselected.sum()
    ctxt_accuracy = ctxt_match[selection].sum() / num_selected
    der_total = 1 - ctxt_match.sum() / num_chars

    diff = T.sum((ctxt_match - base_match)[selection])
    diff_total = diff / num_chars
    diff_relative = diff / num_selected

    selectivity = num_selected / num_chars
    worse_total = base_match[selection].sum() / num_chars

    hidden_der = 1.0 - base_accuracy
    partial_der = 1.0 - ctxt_accuracy
    reader_error = selectivity * partial_der + (1 - selectivity) * hidden_der

    def pct(value: T.Tensor) -> float:
        # Scalar tensor -> percentage with two decimals.
        return round(value.item() * 100, 2)

    return PartialDiacMetrics(
        diff_total=pct(diff_total),
        worse_total=pct(worse_total),
        diff_relative=pct(diff_relative),
        der_total=pct(der_total),
        selectivity=pct(selectivity),
        hidden_der=pct(hidden_der),
        partial_der=pct(partial_der),
        reader_error=pct(reader_error),
    )
200
+
201
+
202
def relative_improvement_soft(
    pred_c : T.Tensor,
    pred_m : T.Tensor,
    labels : T.LongTensor,
    padding_mask: T.Tensor,
):
    """Soft (score-weighted) comparison of ctxt vs. base predictions.

    Args:
        pred_c: Per-class scores of the contextual model, classes on the last axis.
        pred_m: Per-class scores of the base model, same shape as ``pred_c``.
        labels: Ground-truth class ids, shaped like ``pred_c`` without the class axis.
        padding_mask: True at padded positions; those are ignored.

    Returns:
        ``(better, worse, selectivity)`` as scalar tensors, each normalised by
        the number of non-padding characters:
        * better      - summed (ctxt - base) score assigned to the true class,
        * worse       - summed base score of the true class on positions where
          the two argmax predictions differ,
        * selectivity - fraction of positions where the argmax predictions differ.
    """
    #^ pred_c: [b tw tc Classes="15"]
    #^ pred_m: [b tw tc Classes="15"]
    padding_mask = T.as_tensor(padding_mask).to(T.bool)
    keep = ~padding_mask
    num_chars = keep.float().sum()

    pred_c = T.as_tensor(pred_c)[keep]
    pred_m = T.as_tensor(pred_m)[keep]
    #^ : [(b * tw * tc), Classes]
    # gather needs an index with the same rank as its input, so add a trailing
    # axis to the labels and remove it again from the result.
    label_index = T.as_tensor(labels)[keep].long().unsqueeze(-1)
    #^ : [(b * tw * tc), 1 | ClassId]

    ctxt_match = T.gather(pred_c, dim=1, index=label_index).squeeze(-1)
    base_match = T.gather(pred_m, dim=1, index=label_index).squeeze(-1)
    selection = pred_c.argmax(-1) != pred_m.argmax(-1)

    better = T.sum(ctxt_match - base_match) / num_chars
    selectivity = selection.sum() / num_chars
    worse = base_match[selection].sum() / num_chars
    return better, worse, selectivity
228
+
229
def relative_improvement_masked_soft(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    ground_truth: T.LongTensor,
    padding_mask: T.Tensor,
):
    """Placeholder: this metric is not implemented.

    An earlier unreachable draft sketched the intent as the mean softmax
    probability that ``pred_c`` gives the ground-truth class, taken over
    non-padding positions where the ctxt and base argmax differ. That draft
    was never executed, so treat this as a note, not a specification.

    Raises:
        NotImplementedError: always.
    """
    #^ pred_c: [b tw tc "13"]
    #^ pred_m: [b tw tc "13"]
    #^ ground_truth: [b tw tc ClassId]
    raise NotImplementedError("relative_improvement_masked_soft is not implemented")
248
+
249
def coverage_confidence(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    padding_mask: T.Tensor,
    # selection_mask: T.Tensor,
):
    """Placeholder: this metric is not implemented.

    An earlier unreachable draft summed the ctxt argmax ids at positions where
    they differ from the base argmax and divided by the non-padding count.
    That draft was never executed, so treat this as a note, not a
    specification.

    Raises:
        NotImplementedError: always.
    """
    #^ pred_c: [b tw tc "13"]
    #^ pred_m: [b tw tc "13"]
    raise NotImplementedError("coverage_confidence is not implemented")
265
+
266
def cli():
    """Command-line entry point: compare base and ctxt model dumps.

    Loads both dumps with their logits, builds a selection mask according to
    ``--mode`` and prints the partial-diacritization metrics twice: once for
    the real mask and once for a random mask with the same selection rate.

    Raises:
        ValueError: if the two dumps differ in length or neither carries
            ground-truth labels.
    """
    # The text is the description; passed positionally it would become `prog`.
    parser = ArgumentParser(
        description='Compare diacritics from base/ctxt systems with partial diac metrics.'
    )
    parser.add_argument('-m', '--model-output-base', required=True, help="Path to tensor.pt dump files of base diacs.")
    parser.add_argument('-c', '--model-output-ctxt', required=True, help="Path to tensor.pt dump files of ctxt diacs.")
    # NOTE: --gt is accepted but not consumed below; ground truth is taken
    # from the model dumps themselves.
    parser.add_argument('--gt', default=None, help="Path to tensor.pt for gt only.")
    parser.add_argument('--mode', choices=['hard', 'logits'], default='hard')
    args = parser.parse_args()

    # Logits are always loaded: analysis_summary receives them in both modes.
    model_output_base = parse_data(
        load_data(args.model_output_base),
        logits=True,
        side='base',
    )
    model_output_ctxt = parse_data(
        load_data(args.model_output_ctxt),
        logits=True,
        side='ctxt',
    )
    #^ each: (pred ids [b, tc], ground truth, logits)

    logln(f"{model_output_base[0].shape=} , {model_output_ctxt[0].shape=}")

    if len(model_output_base[0]) != len(model_output_ctxt[0]):
        raise ValueError("base and ctxt dumps contain a different number of lines")

    pred_m, gt_m, logits_m = model_output_base
    pred_c, gt_c, logits_c = model_output_ctxt

    # Prefer the base dump's labels, fall back to the ctxt dump's.
    ground_truth = gt_m if gt_m is not None else gt_c
    if ground_truth is None:
        raise ValueError("neither dump provides ground-truth diacritics")

    if args.mode == 'hard':
        selection = make_mask_hard(pred_c, pred_m)
    else:  # 'logits' (argparse restricts the choices)
        selection = make_mask_logits(logits_c, logits_m)

    for title, randomize in (
        ("Actual Totals:", False),
        ("Random Marked Chars:", True),
    ):
        metrics = analysis_summary(
            pred_c, pred_m, ground_truth, ground_truth == -1,
            random=randomize,
            selection=selection,
            logits=(logits_c, logits_m),
        )
        logln(title, metrics)
predict.py CHANGED
@@ -5,7 +5,7 @@ import argparse
5
  import os
6
 
7
  import yaml
8
- from pyarabic.araby import tokenize, strip_tatweel
9
  from tqdm import tqdm
10
 
11
  import numpy as np
@@ -19,6 +19,69 @@ from data_utils import DatasetUtils
19
  from dataloader import DataRetriever
20
  from segment import segment
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  class Predictor:
23
  def __init__(self, config, text):
24
 
@@ -45,8 +108,8 @@ class Predictor:
45
  if T.cuda.is_available() else 'cpu'
46
  )
47
 
48
- self.model = DiacritizerD2(config)
49
- self.model.build(word_embeddings, vocab_size)
50
  state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
51
  self.model.load_state_dict(state_dict)
52
  self.model.to(self.device)
@@ -82,6 +145,13 @@ class PredictTri(Predictor):
82
  y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
83
  diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
84
  return diacritized_lines
 
 
 
 
 
 
 
85
 
86
  def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
87
  assert isinstance(self.model, PartialDD)
@@ -89,7 +159,7 @@ class PredictTri(Predictor):
89
  if not os.path.exists("dataset/cache"):
90
  os.mkdir("dataset/cache")
91
  # segment_outputs = self.model.predict_partial(self.data_loader, return_extra=True)
92
- segment_outputs = self.model.predict_partial(self.data_loader, return_extra=False, eval_only='ctxt')
93
  T.save(segment_outputs, "dataset/cache/cache.pt")
94
  else:
95
  segment_outputs = T.load("dataset/cache/cache.pt")
@@ -107,6 +177,7 @@ class PredictTri(Predictor):
107
  # 'logits': segment_outputs['logits'],
108
  }
109
  }
 
110
  return diacritized_lines, extra_out
111
 
112
  def coalesce_votes_by_majority(
 
5
  import os
6
 
7
  import yaml
8
+ from pyarabic.araby import tokenize, strip_tatweel, strip_tashkeel
9
  from tqdm import tqdm
10
 
11
  import numpy as np
 
19
  from dataloader import DataRetriever
20
  from segment import segment
21
 
22
+ from partial_dd_metrics import (
23
+ parse_data,
24
+ load_data,
25
+ make_mask_hard,
26
+ make_mask_logits,
27
+ )
28
+
29
def apply_tashkeel(
    line: str,
    diacs: Union[np.ndarray, T.Tensor]
):
    """Interleave the characters of ``line`` with their diacritics.

    ``diacs`` is converted with ``DatasetUtils.flat2_3head`` and the resulting
    heads are zipped per character; each per-character tuple is rendered by
    ``DatasetUtils.shakkel_char`` and appended after its character.

    Characters and diacritic entries are paired with ``zip``, so the output
    stops at the shorter of the two.

    Returns:
        The diacritized line as a string.
    """
    diacs_h3 = DatasetUtils.flat2_3head(diacs)
    # Build once with join instead of repeated string concatenation.
    return ''.join(
        ch + DatasetUtils.shakkel_char(*tashkeel)
        for ch, tashkeel in zip(line, zip(*diacs_h3))
    )
39
+
40
def diac_text(data, model_output_base, model_output_ctxt, selection_mode='contrastive-hard', threshold=0.1):
    """Render partially diacritized text from base and ctxt model logits.

    Args:
        data: Input lines, one entry per row of the model outputs.
        model_output_base: Base-model logits, classes on the last axis.
        model_output_ctxt: Contextual-model logits, same shape.
        selection_mode: ``'contrastive-hard'`` keeps a ctxt diacritic wherever
            the two argmax predictions differ; any other value is passed to
            ``make_mask_logits`` as its ``version``.
        threshold: Margin passed to ``make_mask_logits``; unused in
            ``'contrastive-hard'`` mode.

    Returns:
        The diacritized lines joined by newlines. Unselected positions receive
        class id 0.
    """
    ctxt_ids = model_output_ctxt.argmax(-1)
    #^ shape: [b, tc | ClassId]

    if selection_mode == 'contrastive-hard':
        # Compare predicted classes, not raw logits. An element-wise compare of
        # logits yields a [b, tc, C] mask that cannot select from [b, tc] ids.
        selection = make_mask_hard(ctxt_ids, model_output_base.argmax(-1))
    else:
        selection = make_mask_logits(
            model_output_ctxt, model_output_base,
            version=selection_mode, threshold=threshold,
        )
    diacritics = np.where(selection, ctxt_ids, 0).astype(int)

    assert len(model_output_base) == len(data), "one model output row per input line is required"

    # Strip existing diacritics/tatweel and normalise token spacing.
    clean_lines = [
        ' '.join(tokenize(
            line.strip(),
            morphs=[strip_tashkeel, strip_tatweel]
        ))
        for line in data
    ]

    output = [
        apply_tashkeel(line, line_diacs)
        for line, line_diacs in zip(tqdm(clean_lines), diacritics)
    ]
    return '\n'.join(output)
84
+
85
  class Predictor:
86
  def __init__(self, config, text):
87
 
 
108
  if T.cuda.is_available() else 'cpu'
109
  )
110
 
111
+ self.model = PartialDD(config)
112
+ self.model.sentence_diac.build(word_embeddings, vocab_size)
113
  state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
114
  self.model.load_state_dict(state_dict)
115
  self.model.to(self.device)
 
145
  y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
146
  diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
147
  return diacritized_lines
148
+
149
+ def predict_partial(self, do_partial):
150
+ outputs = self.model.predict_partial(self.data_loader, return_extra=True, eval_only='both', do_partial=do_partial)
151
+ y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']
152
+
153
+ diac_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
154
+ return '\n'.join(diac_lines)
155
 
156
  def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
157
  assert isinstance(self.model, PartialDD)
 
159
  if not os.path.exists("dataset/cache"):
160
  os.mkdir("dataset/cache")
161
  # segment_outputs = self.model.predict_partial(self.data_loader, return_extra=True)
162
+ segment_outputs = self.model.predict_partial(self.data_loader, return_extra=False, eval_only='both')
163
  T.save(segment_outputs, "dataset/cache/cache.pt")
164
  else:
165
  segment_outputs = T.load("dataset/cache/cache.pt")
 
177
  # 'logits': segment_outputs['logits'],
178
  }
179
  }
180
+
181
  return diacritized_lines, extra_out
182
 
183
  def coalesce_votes_by_majority(