bkhmsi commited on
Commit
d36d50b
·
1 Parent(s): ba05666

initialized repo

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *.pyc
2
+ *.pt
3
+ *.vec
4
+ .DS_Store
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import gdown
4
+ import gradio as gr
5
+ from predict import PredictTri
6
+
7
+ output_path = "tashkeela-d2.pt"
8
+ if not os.path.exists(output_path):
9
+ model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
10
+ gdown.download(id=model_gdrive_id, output=output_path, quiet=False)
11
+
12
+ output_path = "vocab.vec"
13
+ if not os.path.exists(output_path):
14
+ vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
15
+ gdown.download(id=vocab_gdrive_id, output=output_path, quiet=False)
16
+
17
+ with open("config.yaml", 'r', encoding="utf-8") as file:
18
+ config = yaml.load(file, Loader=yaml.FullLoader)
19
+
20
+ config["train"]["max-sent-len"] = config["predictor"]["window"]
21
+ config["train"]["max-token-count"] = config["predictor"]["window"] * 3
22
+
23
+ def diacritze(text):
24
+ print(text)
25
+ predictor = PredictTri(config, text)
26
+ diacritized_lines = predictor.predict_majority_vote()
27
+ return '\n'.join(diacritized_lines)
28
+
29
+ with gr.Blocks() as demo:
30
+ gr.Markdown(
31
+ """
32
+ # Partial Diacritization
33
+ TODO: put paper links here
34
+ """)
35
+ input_txt = gr.Textbox(
36
+ placeholder="اكتب هنا",
37
+ lines=5,
38
+ label="Input",
39
+ type='text',
40
+ # rtl=True,
41
+ # text_align='right',
42
+ )
43
+
44
+ output_txt = gr.Textbox(
45
+ lines=5,
46
+ label="Output",
47
+ type='text',
48
+ # rtl=True,
49
+ # text_align='right',
50
+ )
51
+
52
+ btn = gr.Button(value="Shakkel")
53
+ btn.click(diacritze, inputs=input_txt, outputs=output_txt)
54
+
55
+ if __name__ == "__main__":
56
+ demo.launch()
components/attention.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import (
2
+ Optional,
3
+ )
4
+ import math
5
+
6
+ import torch as T
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ import opt_einsum as oe
11
+
12
+ from torch import Tensor
13
+
14
+ einsum = oe.contract
15
+
16
+
17
+ def masked_softmax(xs: Tensor, mask: Tensor, dim: int = -1, eps=1e-12):
18
+ xs = xs.masked_fill(~mask, -1e9)
19
+ xs = F.softmax(xs, dim=dim)
20
+ return xs
21
+
22
+ class Attention(nn.Module):
23
+ def __init__(
24
+ self,
25
+ kind: str,
26
+ query_dim: int,
27
+ input_dim: int,
28
+ output_dim: int = None,
29
+ activation: str = 'auto',
30
+ scaled = True,
31
+ ):
32
+ super().__init__()
33
+ assert kind in [
34
+ 'dot',
35
+ 'linear',
36
+ ]
37
+
38
+ self.kind = kind
39
+ self.Dq = query_dim
40
+ self.Din = input_dim
41
+ self.Dout = output_dim or self.Din
42
+ self.activation = 'auto'
43
+ self.scaled = scaled
44
+
45
+ self.Wq_ = nn.Linear(self.Dq, self.Din)
46
+ self.Wk_ = nn.Linear(self.Din, self.Din)
47
+ self.Wv_ = nn.Linear(self.Din, self.Dout)
48
+ self.Wz_ = nn.Linear(self.Din, self.Dout)
49
+
50
+ def forward(
51
+ self,
52
+ query: Tensor,
53
+ data: Tensor,
54
+ content_mask: Optional[Tensor] = None,
55
+ prejudice_mask: Optional[Tensor] = None,
56
+ ):
57
+ #^ query: [b, ts, tw, dq]
58
+ #^ data: [b, ts, di]
59
+ #^ content_mask: [b, ts, tw]
60
+ #^ prejudice_mask: [b, ts, ts]
61
+ #^ => output: [b, ts, tw, dz]
62
+
63
+ dimB, dimS, dimW, dimI = query.shape
64
+
65
+ # TODO: Optimize out the [ts, ts, *] intermediate
66
+ qs = self.Wq_(query)
67
+ ks = self.Wk_(data)
68
+ vs = self.Wv_(data)
69
+
70
+ if content_mask is not None:
71
+ words_mask = content_mask.any(2)
72
+ #^ words_mask : [b, ts]
73
+ else:
74
+ words_mask = qs.new_ones((dimB, dimS))
75
+
76
+ if self.kind == 'linear':
77
+ # Ref: https://twitter.com/francoisfleuret/status/1267455240007188486
78
+ assert prejudice_mask is None, "Linear mode does not support prejudice_mask."
79
+ assert content_mask is not None, "Linear mode requires a content_mask."
80
+ qs = T.relu(qs) * content_mask.unsqueeze(3)
81
+ #^ qs: [bswi]
82
+ ks = T.relu(ks) * words_mask.unsqueeze(2)
83
+ #^ ks: [bsi]
84
+ vks = einsum("bsi, bsz -> bzi", ks, vs)
85
+ #^ vks : [b, dz, di]
86
+ zs = einsum("bswi, bzi -> bswz", qs, vks)
87
+ #^ zs : [b, ts, tw, dz]
88
+ if self.scaled:
89
+ ks = ks.sum(1)
90
+ #^ ks: [bi]
91
+ denom = einsum("bswi, bi -> bsw", qs, ks) + 1e-9
92
+ zs = zs / denom
93
+
94
+ elif self.kind == 'dot':
95
+ # Ref: https://arxiv.org/abs/1706.03762
96
+ # s=ts in q
97
+ # S=ts in ks,vs
98
+ att_map = einsum("bqwi, bki -> bqkw", qs, ks)
99
+ #^ [b, ts:q, ts:k, tw]
100
+ if self.scaled == 'seqlen':
101
+ att_map_ndim = len(att_map.shape) - 1
102
+ norm_coeff = words_mask.sum(1).view(-1, *([1] * att_map_ndim))
103
+ #^ [b, _, _, _]
104
+ att_map = att_map / T.sqrt(norm_coeff.float())
105
+ else:
106
+ att_map = att_map / math.sqrt(self.Din)
107
+
108
+ if content_mask is None and prejudice_mask is None:
109
+ att_map = F.softmax(att_map, dim=2)
110
+ else:
111
+ if content_mask is None:
112
+ assert prejudice_mask is not None # !for mypy
113
+ qk_mask = prejudice_mask.unsqueeze(3)
114
+ #^ qk_mask : [b, ts:q, ts:k, tw^]
115
+ elif prejudice_mask is None:
116
+ qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
117
+ #^ qk_mask : [b, ts:q, ts:k^, tw]
118
+ else:
119
+ qk_mask = words_mask.unsqueeze(1).unsqueeze(3)
120
+ # qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
121
+ qk_mask = qk_mask * prejudice_mask.unsqueeze(3)
122
+ #^ qk_mask : [b, ts:q^, ts:k, tw]
123
+
124
+ att_map = masked_softmax(att_map, qk_mask.bool(), dim=2)
125
+
126
+ #^ att_map : [b, ts:q, ts:k, tw]
127
+ zs = einsum("bqkw, bkz -> bqwz", att_map, vs)
128
+
129
+ zs = self.Wz_(zs)
130
+ return zs, att_map
components/k_lstm.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import (
2
+ Tuple,
3
+ List,
4
+ Optional,
5
+ Dict,
6
+ Callable,
7
+ Union,
8
+ cast,
9
+ )
10
+ from collections import namedtuple
11
+ from abc import ABC, abstractmethod
12
+ from dataclasses import dataclass
13
+
14
+ import numpy as np
15
+
16
+ import torch as T
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+
20
+ from torch import Tensor
21
+
22
+ from .rnn_base import (
23
+ IRecurrentCell,
24
+ IRecurrentCellBuilder,
25
+ RecurrentLayer,
26
+ RecurrentLayerStack,
27
+ )
28
+
29
+ __all__ = [
30
+ 'K_LSTM',
31
+ 'K_LSTM_Cell',
32
+ 'K_LSTM_Cell_Builder',
33
+ ]
34
+
35
+ ACTIVATIONS = {
36
+ 'sigmoid': nn.Sigmoid(),
37
+ 'tanh': nn.Tanh(),
38
+ 'hard_tanh': nn.Hardtanh(),
39
+ 'relu': nn.ReLU(),
40
+ }
41
+
42
+ GateSpans = namedtuple('GateSpans', ['I', 'F', 'G', 'O'])
43
+
44
+ @dataclass
45
+ class K_LSTM_Cell_Builder(IRecurrentCellBuilder):
46
+ vertical_dropout : float = 0.0
47
+ recurrent_dropout : float = 0.0
48
+ recurrent_dropout_mode : str = 'gal_tied'
49
+ input_kernel_initialization : str = 'xavier_uniform'
50
+ recurrent_activation : str = 'sigmoid'
51
+ tied_forget_gate : bool = False
52
+
53
+ def make(self, input_size: int):
54
+ return K_LSTM_Cell(input_size, self)
55
+
56
+ class K_LSTM_Cell(IRecurrentCell):
57
+ def __repr__(self):
58
+ return (
59
+ f'{self.__class__.__name__}('
60
+ + ', '.join(
61
+ [
62
+ f'in: {self.Dx}',
63
+ f'hid: {self.Dh}',
64
+ f'rdo: {self.recurrent_dropout_p} @{self.recurrent_dropout_mode}',
65
+ f'vdo: {self.vertical_dropout_p}'
66
+ ]
67
+ )
68
+ +')'
69
+ )
70
+
71
+ def __init__(
72
+ self,
73
+ input_size: int,
74
+ args: K_LSTM_Cell_Builder,
75
+ ):
76
+ super().__init__()
77
+ self._args = args
78
+ self.Dx = input_size
79
+ self.Dh = args.hidden_size
80
+ self.recurrent_kernel = nn.Linear(self.Dh, self.Dh * 4)
81
+ self.input_kernel = nn.Linear(self.Dx, self.Dh * 4)
82
+
83
+ self.recurrent_dropout_p = args.recurrent_dropout or 0.0
84
+ self.vertical_dropout_p = args.vertical_dropout or 0.0
85
+ self.recurrent_dropout_mode = args.recurrent_dropout_mode
86
+
87
+ self.recurrent_dropout = nn.Dropout(self.recurrent_dropout_p)
88
+ self.vertical_dropout = nn.Dropout(self.vertical_dropout_p)
89
+
90
+ self.tied_forget_gate = args.tied_forget_gate
91
+
92
+ if isinstance(args.recurrent_activation, str):
93
+ self.fun_rec = ACTIVATIONS[args.recurrent_activation]
94
+ else:
95
+ self.fun_rec = args.recurrent_activation
96
+
97
+ self.reset_parameters_()
98
+
99
+ # @T.jit.ignore
100
+ def get_recurrent_weights(self):
101
+ # type: () -> Tuple[GateSpans, GateSpans]
102
+ W = self.recurrent_kernel.weight.chunk(4, 0)
103
+ b = self.recurrent_kernel.bias.chunk(4, 0)
104
+ W = GateSpans(W[0], W[1], W[2], W[3])
105
+ b = GateSpans(b[0], b[1], b[2], b[3])
106
+ return W, b
107
+
108
+ # @T.jit.ignore
109
+ def get_input_weights(self):
110
+ # type: () -> Tuple[GateSpans, GateSpans]
111
+ W = self.input_kernel.weight.chunk(4, 0)
112
+ b = self.input_kernel.bias.chunk(4, 0)
113
+ W = GateSpans(W[0], W[1], W[2], W[3])
114
+ b = GateSpans(b[0], b[1], b[2], b[3])
115
+ return W, b
116
+
117
+ @T.jit.ignore
118
+ def reset_parameters_(self):
119
+ rw, rb = self.get_recurrent_weights()
120
+ iw, ib = self.get_input_weights()
121
+
122
+ nn.init.zeros_(self.input_kernel.bias)
123
+ nn.init.zeros_(self.recurrent_kernel.bias)
124
+ nn.init.ones_(rb.F)
125
+ #^ forget bias
126
+
127
+ for W in rw:
128
+ nn.init.orthogonal_(W)
129
+ for W in iw:
130
+ nn.init.xavier_uniform_(W)
131
+
132
+ @T.jit.export
133
+ def get_init_state(self, input: Tensor) -> Tuple[Tensor, Tensor]:
134
+ batch_size = input.shape[1]
135
+ h0 = T.zeros(batch_size, self.Dh, device=input.device)
136
+ c0 = T.zeros(batch_size, self.Dh, device=input.device)
137
+ return (h0, c0)
138
+
139
+ def apply_input_kernel(self, xt: Tensor) -> List[Tensor]:
140
+ xto = self.vertical_dropout(xt)
141
+ out = self.input_kernel(xto).chunk(4, 1)
142
+ # return cast(List[Tensor], out)
143
+ return out
144
+
145
+ def apply_recurrent_kernel(self, h_tm1: Tensor):
146
+ #^ h_tm1 : [b h]
147
+ mode = self.recurrent_dropout_mode
148
+ if mode == 'gal_tied':
149
+ hto = self.recurrent_dropout(h_tm1)
150
+ out = self.recurrent_kernel(hto)
151
+ #^ out : [b 4h]
152
+ outs = out.chunk(4, -1)
153
+ elif mode == 'gal_gates':
154
+ outs = []
155
+ WW, bb = self.get_recurrent_weights()
156
+ for i in range(4):
157
+ hto = self.recurrent_dropout(h_tm1)
158
+ outs.append(F.linear(hto, WW[i], bb[i]))
159
+ else:
160
+ outs = self.recurrent_kernel(h_tm1).chunk(4, -1)
161
+ return outs
162
+
163
+ def forward(self, input, state):
164
+ # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
165
+ #^ input : [b i]
166
+ #^ state.h : [b h]
167
+
168
+ (h_tm1, c_tm1) = state
169
+
170
+ Xi, Xf, Xg, Xo = self.apply_input_kernel(input)
171
+ Hi, Hf, Hg, Ho = self.apply_recurrent_kernel(h_tm1)
172
+
173
+ ft = self.fun_rec(Xf + Hf)
174
+ ot = self.fun_rec(Xo + Ho)
175
+ if self.tied_forget_gate:
176
+ it = 1.0 - ft
177
+ else:
178
+ it = self.fun_rec(Xi + Hi)
179
+
180
+ gt = T.tanh(Xg + Hg) # * np.sqrt(3)
181
+ if self.recurrent_dropout_mode == 'semeniuta':
182
+ #* https://arxiv.org/abs/1603.05118
183
+ gt = self.recurrent_dropout(gt)
184
+
185
+ ct = (ft * c_tm1) + (it * gt)
186
+
187
+ ht = ot * T.tanh(ct)
188
+
189
+ return ht, (ht, ct)
190
+
191
+ @T.jit.export
192
+ def loop(self, inputs, state_t0, mask=None):
193
+ # type: (List[Tensor], Tuple[Tensor, Tensor], Optional[List[Tensor]]) -> Tuple[List[Tensor], Tuple[Tensor, Tensor]]
194
+ '''
195
+ This loops over t (time) steps
196
+ '''
197
+ #^ inputs : t * [b i]
198
+ #^ state_t0[i] : [b s]
199
+ #^ out : [t b h]
200
+ state = state_t0
201
+ outs = []
202
+ for xt in inputs:
203
+ ht, state = self(xt, state)
204
+ outs.append(ht)
205
+
206
+ return outs, state
207
+
208
+ class K_LSTM(RecurrentLayerStack):
209
+ def __init__(
210
+ self,
211
+ *args,
212
+ **kargs,
213
+ ):
214
+ builder = K_LSTM_Cell_Builder
215
+ super().__init__(
216
+ builder,
217
+ *args, **kargs
218
+ )
components/linear_scheduler.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class LinearSchedule:
2
+ def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
3
+ """Linear interpolation between initial_p and final_p over
4
+ schedule_timesteps. After this many timesteps pass final_p is
5
+ returned.
6
+ Parameters
7
+ ----------
8
+ schedule_timesteps: int
9
+ Number of timesteps for which to linearly anneal initial_p
10
+ to final_p
11
+ initial_p: float
12
+ initial output value
13
+ final_p: float
14
+ final output value
15
+ """
16
+ self.schedule_timesteps = schedule_timesteps
17
+ self.final_p = final_p
18
+ self.initial_p = initial_p
19
+
20
+ def value(self, t):
21
+ """See Schedule.value"""
22
+ fraction = min(float(t) / self.schedule_timesteps, 1.0)
23
+ return self.initial_p + fraction * (self.final_p - self.initial_p)
24
+
components/rnn.py ADDED
File without changes
components/rnn_base.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import (
2
+ Tuple,
3
+ List,
4
+ Union,
5
+ Dict,
6
+ Optional,
7
+ Callable,
8
+ )
9
+ from collections import namedtuple
10
+ from abc import ABC, abstractmethod
11
+
12
+ import torch as T
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+
16
+ from torch import Tensor
17
+
18
+ import pdb
19
+
20
+ from dataclasses import dataclass
21
+
22
+
23
+ class IRecurrentCell(ABC, nn.Module):
24
+ @abstractmethod
25
+ def get_init_state(self, input: Tensor):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def loop(self, inputs, state_t0, mask=None):
30
+ pass
31
+
32
+ # def forward(self, input, state, mask=None):
33
+ # pass
34
+
35
+ @dataclass
36
+ class IRecurrentCellBuilder(ABC):
37
+ hidden_size: int
38
+
39
+ def make(self, input_size: int) -> IRecurrentCell:
40
+ pass
41
+
42
+ def make_scripted(self, *p, **ks) -> IRecurrentCell:
43
+ return T.jit.script(self.make(*p, **ks))
44
+
45
+ class RecurrentLayer(nn.Module):
46
+ def reorder_inputs(self, inputs: Union[List[T.Tensor], T.Tensor]):
47
+ #^ inputs : [t b i]
48
+ if self.direction == 'backward':
49
+ return inputs[::-1]
50
+ return inputs
51
+
52
+ def __init__(
53
+ self,
54
+ cell: IRecurrentCell,
55
+ direction='forward',
56
+ batch_first=False,
57
+ ):
58
+ super().__init__()
59
+ if isinstance(batch_first, bool):
60
+ batch_first = (batch_first, batch_first)
61
+ self.batch_first = batch_first
62
+ self.direction = direction
63
+ self.cell_: IRecurrentCell = cell
64
+
65
+ @T.jit.ignore
66
+ def forward(self, input, state_t0, return_state=None):
67
+ if self.batch_first[0]:
68
+ #^ input : [b t i]
69
+ input = input.transpose(1, 0)
70
+ #^ input : [t b i]
71
+ inputs = input.unbind(0)
72
+
73
+ if state_t0 is None:
74
+ state_t0 = self.cell_.get_init_state(input)
75
+
76
+ inputs = self.reorder_inputs(inputs)
77
+
78
+ if return_state:
79
+ sequence, state = self.cell_.loop(inputs, state_t0)
80
+ else:
81
+ sequence, _ = self.cell_.loop(inputs, state_t0)
82
+ #^ sequence : t * [b h]
83
+ sequence = self.reorder_inputs(sequence)
84
+ sequence = T.stack(sequence)
85
+ #^ sequence : [t b h]
86
+
87
+ if self.batch_first[1]:
88
+ sequence = sequence.transpose(1, 0)
89
+ #^ sequence : [b t h]
90
+
91
+ if return_state:
92
+ return sequence, state
93
+ else:
94
+ return sequence, None
95
+
96
+ class BidirectionalRecurrentLayer(nn.Module):
97
+ def __init__(
98
+ self,
99
+ input_size: int,
100
+ cell_builder: IRecurrentCellBuilder,
101
+ batch_first=False,
102
+ return_states=False
103
+ ):
104
+ super().__init__()
105
+ self.batch_first = batch_first
106
+ self.cell_builder = cell_builder
107
+ self.batch_first = batch_first
108
+ self.return_states = return_states
109
+ self.fwd = RecurrentLayer(
110
+ cell_builder.make_scripted(input_size),
111
+ direction='forward',
112
+ batch_first=batch_first
113
+ )
114
+ self.bwd = RecurrentLayer(
115
+ cell_builder.make_scripted(input_size),
116
+ direction='backward',
117
+ batch_first=batch_first
118
+ )
119
+
120
+ @T.jit.ignore
121
+ def forward(self, input, state_t0, is_last):
122
+ return_states = is_last and self.return_states
123
+ if return_states:
124
+ fwd, state_fwd = self.fwd(input, state_t0, return_states)
125
+ bwd, state_bwd = self.bwd(input, state_t0, return_states)
126
+ return T.cat([fwd, bwd], dim=-1), (T.cat([state_fwd[0], state_bwd[0]], dim=-1), T.cat([state_fwd[1], state_bwd[1]], dim=-1))
127
+ else:
128
+ fwd, _ = self.fwd(input, state_t0, return_states)
129
+ bwd, _ = self.bwd(input, state_t0, return_states)
130
+ return T.cat([fwd, bwd], dim=-1), None
131
+
132
+ class RecurrentLayerStack(nn.Module):
133
+ def __init__(
134
+ self,
135
+ cell_builder : Callable[..., IRecurrentCellBuilder],
136
+ input_size : int,
137
+ num_layers : int,
138
+ bidirectional : bool = False,
139
+ batch_first : bool = False,
140
+ scripted : bool = True,
141
+ return_states : bool = False,
142
+ *args, **kargs,
143
+ ):
144
+ super().__init__()
145
+ cell_builder_: IRecurrentCellBuilder = cell_builder(*args, **kargs)
146
+ self._cell_builder = cell_builder_
147
+
148
+ if bidirectional:
149
+ Dh = cell_builder_.hidden_size * 2
150
+ def make(isize: int, last=False):
151
+ return BidirectionalRecurrentLayer(isize, cell_builder_,
152
+ batch_first=batch_first, return_states=return_states)
153
+ else:
154
+ Dh = cell_builder_.hidden_size
155
+ def make(isize: int, last=False):
156
+ cell = cell_builder_.make_scripted(isize)
157
+ return RecurrentLayer(cell, isize,
158
+ batch_first=batch_first)
159
+
160
+
161
+ if num_layers > 1:
162
+ rnns = [
163
+ make(input_size),
164
+ *[
165
+ make(Dh)
166
+ for _ in range(num_layers - 2)
167
+ ],
168
+ make(Dh, last=True)
169
+ ]
170
+ else:
171
+ rnns = [make(input_size, last=True)]
172
+
173
+ self.rnn = nn.Sequential(*rnns)
174
+
175
+ self.input_size = input_size
176
+ self.hidden_size = self._cell_builder.hidden_size
177
+ self.num_layers = num_layers
178
+ self.bidirectional = bidirectional
179
+ self.return_states = return_states
180
+
181
+ def __repr__(self):
182
+ return (
183
+ f'${self.__class__.__name__}'
184
+ + '('
185
+ + f'in={self.input_size}, '
186
+ + f'hid={self.hidden_size}, '
187
+ + f'layers={self.num_layers}, '
188
+ + f'bi={self.bidirectional}'
189
+ + '; '
190
+ + str(self._cell_builder)
191
+ )
192
+
193
+ def forward(self, input, state_t0=None):
194
+ for layer_idx, rnn in enumerate(self.rnn):
195
+ is_last = (layer_idx == (len(self.rnn) - 1))
196
+ input, state = rnn(input, state_t0, is_last)
197
+ if self.return_states:
198
+ return input, state
199
+ return input
config.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run-title: tashkeela-d2
2
+ debug: false
3
+
4
+ paths:
5
+ base: ./dataset/ashaar
6
+ save: ./models
7
+ load: tashkeela-d2.pt
8
+ resume: ./models/Tashkeela-D2/tashkeela-d2.pt
9
+ constants: ./dataset/helpers/constants
10
+ word-embs: vocab.vec
11
+ test: test
12
+
13
+ loader:
14
+ wembs-limit: -1
15
+ num-workers: 0
16
+
17
+ train:
18
+ epochs: 1000
19
+ batch-size: 32
20
+ char-embed-dim: 32
21
+ resume: false
22
+ resume-lr: false
23
+
24
+ max-word-len: 13
25
+ max-sent-len: 10
26
+
27
+ rnn-cell: lstm
28
+ sent-lstm-layers: 2
29
+ word-lstm-layers: 2
30
+
31
+ sent-lstm-units: 256
32
+ word-lstm-units: 512
33
+ decoder-units: 256
34
+
35
+ sent-dropout: 0.2
36
+ diac-dropout: 0
37
+ final-dropout: 0.2
38
+
39
+ sent-mask-zero: false
40
+
41
+ lr-factor: 0.5
42
+ lr-patience: 1
43
+ lr-min: 1.e-7
44
+ lr-init: 0.002
45
+
46
+ weight-decay: 0
47
+ vertical-dropout: 0.25
48
+ recurrent-dropout: 0.25
49
+
50
+ stopping-delta: 1.e-7
51
+ stopping-patience: 3
52
+
53
+ predictor:
54
+ batch-size: 75
55
+ stride: 2
56
+ window: 20
57
+ gt-signal-prob: 0
58
+ seed-idx: 0
59
+
60
+ sentence-break:
61
+ stride: 2
62
+ window: 10
63
+ min-window: 1
64
+ export-map: false
65
+ files:
66
+ - train/train.txt
67
+ - val/val.txt
68
+ delimeters:
69
+ - ،
70
+ - ؛
71
+ - ','
72
+ - ;
73
+ - «
74
+ - »
75
+ - '{'
76
+ - '}'
77
+ - '('
78
+ - ')'
79
+ - '['
80
+ - ']'
81
+ - '.'
82
+ - '*'
83
+ - '-'
84
+ - ':'
85
+ - '?'
86
+ - '!'
87
+ - ؟
88
+
89
+
90
+ segment:
91
+ stride: 2
92
+ window: 10
93
+ min-window: 1
94
+ export-map: false
95
+ files:
96
+ - train/train.txt
97
+ - val/val.txt
98
+ delimeters:
99
+ - ،
100
+ - ؛
101
+ - ','
102
+ - ;
103
+ - «
104
+ - »
105
+ - '{'
106
+ - '}'
107
+ - '('
108
+ - ')'
109
+ - '['
110
+ - ']'
111
+ - '.'
112
+ - '*'
113
+ - '-'
114
+ - ':'
115
+ - '?'
116
+ - '!'
117
+ - ؟
data_utils.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import numpy as np
4
+
5
+ from tqdm import tqdm
6
+ from prettytable import PrettyTable
7
+ from pyarabic.araby import tokenize, strip_tashkeel
8
+ import diac_utils as du
9
+
10
+ class DatasetUtils:
11
+ def __init__(self, config):
12
+ self.base_path = config["paths"]["base"]
13
+ self.special_tokens = ['<pad>', '<unk>', '<num>', '<punc>']
14
+ self.delimeters = config["sentence-break"]["delimeters"]
15
+ self.load_constants(config["paths"]["constants"])
16
+ self.debug = config["debug"]
17
+
18
+ self.stride = config["sentence-break"]["stride"]
19
+ self.window = config["sentence-break"]["window"]
20
+ self.val_stride = config["sentence-break"].get("val-stride", self.stride)
21
+
22
+ self.test_stride = config["predictor"]["stride"]
23
+ self.test_window = config["predictor"]["window"]
24
+
25
+ self.max_word_len = config["train"]["max-word-len"]
26
+ self.max_sent_len = config["train"]["max-sent-len"]
27
+ self.max_token_count = config["train"]["max-token-count"]
28
+ self.pad_target_val = -100
29
+ self.pad_char_id = du.LETTER_LIST.index('<pad>')
30
+
31
+ self.markov_signal = config['train'].get('markov-signal', False)
32
+ self.batch_first = config['train'].get('batch-first', True)
33
+
34
+ self.gt_prob = config["predictor"]["gt-signal-prob"]
35
+ if self.gt_prob > 0:
36
+ self.s_idx = config["predictor"]["seed-idx"]
37
+ subpath = f"test_gt_mask_{self.gt_prob}_{self.s_idx}.txt"
38
+ mask_path = os.path.join(self.base_path, "test", subpath)
39
+ with open(mask_path, 'r') as fin:
40
+ self.gt_mask = fin.readlines()
41
+
42
+ if "word-embs" in config["paths"] and config["paths"]["word-embs"].strip() != "":
43
+ self.pad_val = self.special_tokens.index("<pad>")
44
+ self.embeddings, self.vocab = self.load_embeddings(config["paths"]["word-embs"], config["loader"]["wembs-limit"])
45
+ self.embeddings = self.normalize(self.embeddings, ["unit", "centeremb", "unit"])
46
+ self.w2idx = {word: i for i, word in enumerate(self.vocab)}
47
+
48
+ def load_file(self, path):
49
+ with open(path, 'rb') as f:
50
+ return list(pickle.load(f))
51
+
52
+ def normalize(self, matrix, actions, mean=None):
53
+ def length_normalize(matrix):
54
+ norms = np.sqrt(np.sum(matrix**2, axis=1))
55
+ norms[norms == 0] = 1
56
+ matrix = matrix / norms[:, np.newaxis]
57
+ return matrix
58
+
59
+ def mean_center(matrix):
60
+ return matrix - mean
61
+
62
+ def length_normalize_dimensionwise(matrix):
63
+ norms = np.sqrt(np.sum(matrix**2, axis=0))
64
+ norms[norms == 0] = 1
65
+ matrix = matrix / norms
66
+ return matrix
67
+
68
+ def mean_center_embeddingwise(matrix):
69
+ avg = np.mean(matrix, axis=1)
70
+ matrix = matrix - avg[:, np.newaxis]
71
+ return matrix
72
+
73
+ for action in actions:
74
+ if action == 'unit':
75
+ matrix = length_normalize(matrix)
76
+ elif action == 'center':
77
+ matrix = mean_center(matrix)
78
+ elif action == 'unitdim':
79
+ matrix = length_normalize_dimensionwise(matrix)
80
+ elif action == 'centeremb':
81
+ matrix = mean_center_embeddingwise(matrix)
82
+
83
+ return matrix
84
+
85
+ def load_constants(self, path):
86
+ # self.numbers = [c for c in "0123456789"]
87
+ # self.letter_list = self.special_tokens + self.load_file(os.path.join(path, 'ARABIC_LETTERS_LIST.pickle'))
88
+ # self.diacritic_list = [' '] + self.load_file(os.path.join(path, 'DIACRITICS_LIST.pickle'))
89
+ self.numbers = du.NUMBERS
90
+ self.letter_list = du.LETTER_LIST
91
+ self.diacritic_list = du.DIACRITICS_SHORT
92
+
93
+ def split_word_on_characters_with_diacritics(self, word: str):
94
+ return du.split_word_on_characters_with_diacritics(word)
95
+
96
+ def load_mapping_v3(self, dtype, file_ext=None):
97
+ mapping = {}
98
+ if file_ext is None:
99
+ file_ext = f"-{self.test_stride}-{self.test_window}.map"
100
+ f_name = os.path.join(self.base_path, dtype, dtype + file_ext)
101
+ with open(f_name, 'r') as fin:
102
+ for line in fin:
103
+ sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
104
+ if sent_idx not in mapping:
105
+ mapping[sent_idx] = {}
106
+ if seg_idx not in mapping[sent_idx]:
107
+ mapping[sent_idx][seg_idx] = {}
108
+ if t_idx not in mapping[sent_idx][seg_idx]:
109
+ mapping[sent_idx][seg_idx][t_idx] = []
110
+ mapping[sent_idx][seg_idx][t_idx] += [c_idx]
111
+ return mapping
112
+
113
+ def load_mapping_v3_from_list(self, mapping_list):
114
+ mapping = {}
115
+ for line in mapping_list:
116
+ sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
117
+ if sent_idx not in mapping:
118
+ mapping[sent_idx] = {}
119
+ if seg_idx not in mapping[sent_idx]:
120
+ mapping[sent_idx][seg_idx] = {}
121
+ if t_idx not in mapping[sent_idx][seg_idx]:
122
+ mapping[sent_idx][seg_idx][t_idx] = []
123
+ mapping[sent_idx][seg_idx][t_idx] += [c_idx]
124
+ return mapping
125
+
126
+ def load_embeddings(self, embs_path, limit=-1):
127
+ if self.debug:
128
+ return np.zeros((200+len(self.special_tokens),300)), self.special_tokens + ["c"] * 200
129
+
130
+ words = [self.special_tokens[0]]
131
+ print(f"[INFO] Reading Embeddings from {embs_path}")
132
+ with open(embs_path, encoding='utf-8', mode='r') as fin:
133
+ n, d = map(int, fin.readline().split())
134
+ limit = n if limit <= 0 else limit
135
+ embeddings = np.zeros((limit+1, d))
136
+ for i, line in tqdm(enumerate(fin), total=limit):
137
+ if i >= limit: break
138
+ tokens = line.rstrip().split()
139
+ words += [tokens[0]]
140
+ embeddings[i+1] = list(map(float, tokens[1:]))
141
+ return embeddings, words
142
+
143
+ def load_file_clean(self, dtype, strip=False):
144
+ f_name = os.path.join(self.base_path, dtype, dtype + ".txt")
145
+ with open(f_name, 'r', encoding="utf-8", newline='\n') as fin:
146
+ if strip:
147
+ original_lines = [strip_tashkeel(self.preprocess(line)) for line in fin.readlines()]
148
+ else:
149
+ original_lines = [self.preprocess(line) for line in fin.readlines()]
150
+ return original_lines
151
+
152
+ def preprocess(self, line):
153
+ return ' '.join(tokenize(line))
154
+
155
+ def pad_and_truncate_sequence(self, tokens, max_len, pad=None):
156
+ if pad is None:
157
+ pad = self.special_tokens.index("<pad>")
158
+ if len(tokens) < max_len:
159
+ offset = max_len - len(tokens)
160
+ return tokens + [pad] * offset
161
+ else:
162
+ return tokens[:max_len]
163
+
164
+ def stats(self, freq, percentile=90, name="stats"):
165
+ table = PrettyTable(["Dataset", "Mean", "Std", "Min", "Max", f"{percentile}th Percentile"])
166
+ freq = np.array(sorted(freq))
167
+ table.add_row([name, freq.mean(), freq.std(), freq.min(), freq.max(), np.percentile(freq, percentile)])
168
+ print(table)
169
+
170
+ def create_gt_mask(self, lines, prob, idx, seed=1111):
171
+ np.random.seed(seed)
172
+
173
+ gt_masks = []
174
+ for line in lines:
175
+ tokens = tokenize(line.strip())
176
+ gt_mask_token = ""
177
+ for t_idx, token in enumerate(tokens):
178
+ gt_mask_token += ''.join(map(str, np.random.binomial(1, prob, len(token))))
179
+ if t_idx+1 < len(tokens):
180
+ gt_mask_token += " "
181
+ gt_masks += [gt_mask_token]
182
+
183
+ subpath = f"test_gt_mask_{prob}_{idx}.txt"
184
+ mask_path = os.path.join(self.base_path, "test", subpath)
185
+
186
+ with open(mask_path, 'w') as fout:
187
+ fout.write('\n'.join(gt_masks))
188
+
189
+ def create_gt_labels(self, lines):
190
+ gt_labels = []
191
+ for line in lines:
192
+ gt_labels_line = []
193
+ tokens = tokenize(line.strip())
194
+ for w_idx, word in enumerate(tokens):
195
+ split_word = self.split_word_on_characters_with_diacritics(word)
196
+ _, cy_flat, _ = du.create_label_for_word(split_word)
197
+
198
+ gt_labels_line.extend(cy_flat)
199
+ if w_idx+1 < len(tokens):
200
+ gt_labels_line += [0]
201
+
202
+ gt_labels += [gt_labels_line]
203
+ return gt_labels
204
+
205
+ def get_ce(self, diac_word_y, e_idx=None, return_idx=False):
206
+ #^ diac_word_y: [Tw 3]
207
+ if e_idx is None: e_idx = len(diac_word_y)
208
+ for c_idx in reversed(range(e_idx)):
209
+ if diac_word_y[c_idx] != [0,0,0]:
210
+ return diac_word_y[c_idx] if not return_idx else c_idx
211
+ return diac_word_y[e_idx-1] if not return_idx else e_idx-1
212
+
213
+ def create_decoder_input(self, diac_code_y, prob=0):
214
+ #^ diac_code_y: [Ts Tw 3]
215
+ diac_code_x = np.zeros((*np.array(diac_code_y).shape[:-1], 8))
216
+ if not self.markov_signal:
217
+ return list(diac_code_x)
218
+ prev_ce = list(np.eye(6)[-1]) + [0,0] # bos tag
219
+ for w_idx, word in enumerate(diac_code_y):
220
+ diac_code_x[w_idx, 0, :] = prev_ce
221
+ for c_idx, char in enumerate(word[:-1]):
222
+ # if np.random.rand() < prob:
223
+ # continue
224
+ if char[0] == self.pad_target_val:
225
+ break
226
+ haraka = list(np.eye(6)[char[0]])
227
+ diac_code_x[w_idx, c_idx+1, :] = haraka + char[1:]
228
+ ce = self.get_ce(diac_code_y[w_idx], c_idx)
229
+ prev_ce = list(np.eye(6)[ce[0]]) + ce[1:]
230
+ return list(diac_code_x)
dataloader.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from pyarabic.araby import tokenize, strip_tashkeel
4
+
5
+ import numpy as np
6
+ import torch as T
7
+ from torch.utils.data import Dataset
8
+
9
+ from data_utils import DatasetUtils
10
+ import diac_utils as du
11
+
12
+ class DataRetriever(Dataset):
13
+ def __init__(self, data_utils : DatasetUtils, lines: list):
14
+ super(DataRetriever).__init__()
15
+
16
+ self.data_utils = data_utils
17
+ self.lines = lines
18
+
19
+ def preprocess(self, data, dtype=T.long):
20
+ return [T.tensor(np.array(x), dtype=dtype) for x in data]
21
+
22
+ def __len__(self):
23
+ return len(self.lines)
24
+
25
+ def __getitem__(self, idx):
26
+ word_x, char_x, diac_x, diac_y = self.create_sentence(idx)
27
+ return self.preprocess((word_x, char_x, diac_x)), T.tensor(diac_y, dtype=T.long), T.tensor(diac_y, dtype=T.long)
28
+
29
+ def create_sentence(self, idx):
30
+ line = self.lines[idx]
31
+ tokens = tokenize(line.strip())
32
+
33
+ word_x = []
34
+ char_x = []
35
+ diac_x = []
36
+ diac_y = []
37
+ diac_y_tmp = []
38
+
39
+ for word in tokens:
40
+ word = du.strip_unknown_tashkeel(word)
41
+ word_chars = du.split_word_on_characters_with_diacritics(word)
42
+ cx, cy, cy_3head = du.create_label_for_word(word_chars)
43
+
44
+ word_strip = strip_tashkeel(word)
45
+ word_x += [self.data_utils.w2idx[word_strip] if word_strip in self.data_utils.w2idx else self.data_utils.w2idx["<pad>"]]
46
+
47
+ char_x += [self.data_utils.pad_and_truncate_sequence(cx, self.data_utils.max_word_len)]
48
+
49
+ diac_y += [self.data_utils.pad_and_truncate_sequence(cy, self.data_utils.max_word_len, pad=self.data_utils.pad_target_val)]
50
+ diac_y_tmp += [self.data_utils.pad_and_truncate_sequence(cy_3head, self.data_utils.max_word_len, pad=[self.data_utils.pad_target_val]*3)]
51
+
52
+ diac_x = self.data_utils.create_decoder_input(diac_y_tmp)
53
+
54
+ max_slen = self.data_utils.max_sent_len
55
+ max_wlen = self.data_utils.max_word_len
56
+ p_val = self.data_utils.pad_val
57
+ pt_val = self.data_utils.pad_target_val
58
+
59
+ word_x = self.data_utils.pad_and_truncate_sequence(word_x, max_slen)
60
+ char_x = self.data_utils.pad_and_truncate_sequence(char_x, max_slen, pad=[p_val]*max_wlen)
61
+ diac_x = self.data_utils.pad_and_truncate_sequence(diac_x, max_slen, pad=[[p_val]*8]*max_wlen)
62
+ diac_y = self.data_utils.pad_and_truncate_sequence(diac_y, max_slen, pad=[pt_val]*max_wlen)
63
+
64
+ return word_x, char_x, diac_x, diac_y
diac_utils.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch as T
4
+ import numpy as np
5
+
6
+ from pyarabic.araby import (
7
+ tokenize,
8
+ strip_tashkeel,
9
+ strip_tatweel,
10
+ DIACRITICS
11
+ )
12
+
13
+ SEPARATE_DIACRITICS = {
14
+ "FATHA": 1,
15
+ "KASRA": 2,
16
+ "DAMMA": 3,
17
+ "SUKUN": 4
18
+ }
19
+
20
+ HARAKAT_MAP = [
21
+ #^ (haraka, tanween, shadda)
22
+ (0,0,0), #< No diacs on char
23
+ (1,0,0),
24
+ (1,1,0), #< Tanween on 2nd slot
25
+ (2,0,0),
26
+ (2,1,0),
27
+ (3,0,0),
28
+ (3,1,0),
29
+ (4,0,0),
30
+ (0,0,1), #< shadda on 3rd slot
31
+ (1,0,1),
32
+ (1,1,1),
33
+ (2,0,1),
34
+ (2,1,1),
35
+ (3,0,1),
36
+ (3,1,1),
37
+ (0,0,0), #< Padding == -1 (also for spaces)
38
+ ]
39
+
40
+ SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
41
+ LETTER_LIST = SPECIAL_TOKENS + list("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىي")
42
+ CLASSES_LIST = [' ', 'َ', 'ً', 'ُ', 'ٌ', 'ِ', 'ٍ', 'ْ', 'ّ', 'َّ', 'ًّ', 'ُّ', 'ٌّ', 'ِّ', 'ٍّ']
43
+ DIACRITICS_SHORT = [' ', 'َ', 'ً', 'ِ', 'ٍ', 'ُ', 'ٌ', 'ْ', 'ّ']
44
+ NUMBERS = list("0123456789")
45
+ DELIMITERS = ["،","؛",",",";","«","»","{","}","(",")","[","]",".","*","-",":","?","!","؟"]
46
+
47
+ UNKNOWN_DIACRITICS = list(set(DIACRITICS).difference(set(DIACRITICS_SHORT)))
48
+
49
+ def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
50
+ returned_text = ""
51
+ if shadda and diac != SEPARATE_DIACRITICS["SUKUN"]:
52
+ returned_text += "\u0651"
53
+
54
+ if diac == SEPARATE_DIACRITICS["FATHA"]:
55
+ returned_text += "\u064E" if not tanween else "\u064B"
56
+ elif diac == SEPARATE_DIACRITICS["KASRA"]:
57
+ returned_text += "\u0650" if not tanween else "\u064D"
58
+ elif diac == SEPARATE_DIACRITICS["DAMMA"]:
59
+ returned_text += "\u064F" if not tanween else "\u064C"
60
+ elif diac == SEPARATE_DIACRITICS["SUKUN"]:
61
+ returned_text += "\u0652"
62
+
63
+ return returned_text
64
+
65
+ def diac_ids_of_line(line: str):
66
+ words = tokenize(line)
67
+ diacs = []
68
+ for word in words:
69
+ word_chars = split_word_on_characters_with_diacritics(word)
70
+ cx, cy, cy_3head = create_label_for_word(word_chars)
71
+ diacs.extend(cy)
72
+ diacs.append(-1)
73
+ return np.array(diacs[:-1])
74
+
75
+ def strip_unknown_tashkeel(word: str):
76
+ #! FIXME! warnings.warn("Stripping unknown tashkeel is disabled.")
77
+ return word
78
+ return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)
79
+
80
+ def split_word_on_characters_with_diacritics(word: str):
81
+ '''
82
+ TODO! Make faster without deque and looping
83
+ Returns: List[List[char: "letter or diacritic"]]
84
+ '''
85
+ chars_w_diac = []
86
+ i_start = 0
87
+ for i_c, c in enumerate(word):
88
+ #! FIXME! DIACRITICS_SHORT is missing a lot of less common diacritics ...
89
+ #! which are then treated as letters during splitting.
90
+ # if c not in DIACRITICS:
91
+ if c not in DIACRITICS_SHORT:
92
+ sub = list(word[i_start:i_c])
93
+ chars_w_diac.append(sub)
94
+ i_start = i_c
95
+ sub = list(word[i_start:])
96
+ if sub:
97
+ chars_w_diac.append(sub)
98
+ if not chars_w_diac[0]:
99
+ chars_w_diac = chars_w_diac[1:]
100
+ return chars_w_diac
101
+
102
+
103
+ def char_type(char: str):
104
+ if char in LETTER_LIST:
105
+ return LETTER_LIST.index(char)
106
+ elif char in NUMBERS:
107
+ return LETTER_LIST.index('<num>')
108
+ elif char in DELIMITERS:
109
+ return LETTER_LIST.index('<punc>')
110
+ else:
111
+ return LETTER_LIST.index('<unk>')
112
+
113
+ def create_labels(char_w_diac: str):
114
+ remap_dict = {0: 0, 1: 1, 3: 2, 5: 3, 7: 4}
115
+ char_w_diac = [char_w_diac[0]] + list(set(char_w_diac[1:]))
116
+ if len(char_w_diac) > 3:
117
+ char_w_diac = char_w_diac[:2] if DIACRITICS_SHORT[8] not in char_w_diac else char_w_diac[:3]
118
+
119
+ char_idx = None
120
+ diacritic_index = None
121
+ head_3 = None
122
+
123
+ char_idx = char_type(char_w_diac[0])
124
+ diacs = set(char_w_diac[1:])
125
+ diac_h3 = [0, 0, 0]
126
+ for diac in diacs:
127
+ if diac in DIACRITICS_SHORT:
128
+ diac_idx = DIACRITICS_SHORT.index(diac)
129
+ if diac_idx in [2, 4, 6]: #< Tanween
130
+ diac_h3[0] = remap_dict[diac_idx - 1]
131
+ diac_h3[1] = 1
132
+ elif diac_idx == 8: #< shadda
133
+ diac_h3[2] = 1
134
+ else: #< Haraka or sukoon
135
+ diac_h3[0] = remap_dict[diac_idx]
136
+ assert not (diac_h3[0] == 4 and (diac_h3[1] or diac_h3[2]))
137
+ diacritic_index = HARAKAT_MAP.index(tuple(diac_h3))
138
+ return char_idx, diacritic_index, diac_h3
139
+ if len(char_w_diac) == 1:
140
+ return char_idx, 0, [remap_dict[0], 0, 0]
141
+ elif len(char_w_diac) == 2: # If shadda OR diac
142
+ diacritic_index = DIACRITICS_SHORT.index(char_w_diac[1])
143
+ if diacritic_index in [2, 4, 6]: # list of tanween
144
+ head_3 = [remap_dict[diacritic_index - 1], 1, 0]
145
+ elif diacritic_index == 8:
146
+ head_3 = [0, 0, 1]
147
+ else:
148
+ head_3 = [remap_dict[diacritic_index], 0, 0]
149
+ elif len(char_w_diac) == 3: # If shadda AND diac
150
+ if DIACRITICS_SHORT[8] == char_w_diac[1]:
151
+ diacritic_index = DIACRITICS_SHORT.index(char_w_diac[2])
152
+ else:
153
+ diacritic_index = DIACRITICS_SHORT.index(char_w_diac[1])
154
+
155
+ if diacritic_index in [2, 4, 6]: # list of tanween
156
+ head_3 = [remap_dict[diacritic_index - 1], 1, 1]
157
+ else:
158
+ head_3 = [remap_dict[diacritic_index], 0, 1]
159
+ diacritic_index = diacritic_index+8
160
+
161
+ return char_idx, diacritic_index, head_3
162
+
163
+ def create_label_for_word(split_word: List[List[str]]):
164
+ word_char_indices = []
165
+ word_diac_indices = []
166
+ word_diac_indices_h3 = []
167
+ for char_w_diac in split_word:
168
+ char_idx, diac_idx, diac_h3 = create_labels(char_w_diac)
169
+ if char_idx == None:
170
+ print(split_word)
171
+ raise ValueError(char_idx)
172
+ word_char_indices.append(char_idx)
173
+ word_diac_indices.append(diac_idx)
174
+ word_diac_indices_h3.append(diac_h3)
175
+ return word_char_indices, word_diac_indices, word_diac_indices_h3
176
+
177
+
178
+ def flat_2_3head(output: T.Tensor):
179
+ '''
180
+ output: [b tw tc]
181
+ '''
182
+ haraka, tanween, shadda = [], [], []
183
+
184
+ # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
185
+ # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
186
+
187
+ b, ts, tw = output.shape
188
+
189
+ for b_idx in range(b):
190
+ h_s, t_s, s_s = [], [], []
191
+ for w_idx in range(ts):
192
+ h_w, t_w, s_w = [], [], []
193
+ for c_idx in range(tw):
194
+ c = HARAKAT_MAP[int(output[b_idx, w_idx, c_idx])]
195
+ h_w += [c[0]]
196
+ t_w += [c[1]]
197
+ s_w += [c[2]]
198
+ h_s += [h_w]
199
+ t_s += [t_w]
200
+ s_s += [s_w]
201
+
202
+ haraka += [h_s]
203
+ tanween += [t_s]
204
+ shadda += [s_s]
205
+
206
+
207
+ return haraka, tanween, shadda
208
+
209
+ def flat2_3head(diac_idx):
210
+ '''
211
+ diac_idx: [tw]
212
+ '''
213
+ haraka, tanween, shadda = [], [], []
214
+ # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
215
+ # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD
216
+
217
+ for diac in diac_idx:
218
+ c_out = HARAKAT_MAP[diac]
219
+ haraka += [c_out[0]]
220
+ tanween += [c_out[1]]
221
+ shadda += [c_out[2]]
222
+
223
+ return np.array(haraka), np.array(tanween), np.array(shadda)
model_dd.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch as T
3
+
4
+ from tqdm import tqdm
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from components.k_lstm import K_LSTM
9
+ from components.attention import Attention
10
+ from data_utils import DatasetUtils
11
+ from diac_utils import flat2_3head, flat_2_3head
12
+
13
+ class DiacritizerD2(nn.Module):
14
+ def __init__(self, config):
15
+ super(DiacritizerD2, self).__init__()
16
+ self.max_word_len = config["train"]["max-word-len"]
17
+ self.max_sent_len = config["train"]["max-sent-len"]
18
+ self.char_embed_dim = config["train"]["char-embed-dim"]
19
+
20
+ self.final_dropout_p = config["train"]["final-dropout"]
21
+ self.sent_dropout_p = config["train"]["sent-dropout"]
22
+ self.diac_dropout_p = config["train"]["diac-dropout"]
23
+ self.vertical_dropout = config['train']['vertical-dropout']
24
+ self.recurrent_dropout = config['train']['recurrent-dropout']
25
+ self.recurrent_dropout_mode = config['train'].get('recurrent-dropout-mode', 'gal_tied')
26
+ self.recurrent_activation = config['train'].get('recurrent-activation', 'sigmoid')
27
+
28
+ self.sent_lstm_units = config["train"]["sent-lstm-units"]
29
+ self.word_lstm_units = config["train"]["word-lstm-units"]
30
+ self.decoder_units = config["train"]["decoder-units"]
31
+
32
+ self.sent_lstm_layers = config["train"]["sent-lstm-layers"]
33
+ self.word_lstm_layers = config["train"]["word-lstm-layers"]
34
+
35
+ self.cell = config['train'].get('rnn-cell', 'lstm')
36
+ self.num_layers = config["train"].get("num-layers", 2)
37
+ self.RNN_Layer = K_LSTM
38
+
39
+ self.batch_first = config['train'].get('batch-first', True)
40
+ self.device = 'cuda' if T.cuda.is_available() else 'cpu'
41
+ self.num_classes = 15
42
+
43
+ def build(self, wembs: T.Tensor, abjad_size: int):
44
+ self.closs = F.cross_entropy
45
+ self.bloss = F.binary_cross_entropy_with_logits
46
+
47
+ rnn_kargs = dict(
48
+ recurrent_dropout_mode=self.recurrent_dropout_mode,
49
+ recurrent_activation=self.recurrent_activation,
50
+ )
51
+
52
+ self.sent_lstm = self.RNN_Layer(
53
+ input_size=300,
54
+ hidden_size=self.sent_lstm_units,
55
+ num_layers=self.sent_lstm_layers,
56
+ bidirectional=True,
57
+ vertical_dropout=self.vertical_dropout,
58
+ recurrent_dropout=self.recurrent_dropout,
59
+ batch_first=self.batch_first,
60
+ **rnn_kargs,
61
+ )
62
+
63
+ self.word_lstm = self.RNN_Layer(
64
+ input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
65
+ hidden_size=self.word_lstm_units,
66
+ num_layers=self.word_lstm_layers,
67
+ bidirectional=True,
68
+ vertical_dropout=self.vertical_dropout,
69
+ recurrent_dropout=self.recurrent_dropout,
70
+ batch_first=self.batch_first,
71
+ return_states=True,
72
+ **rnn_kargs,
73
+ )
74
+
75
+ self.char_embs = nn.Embedding(
76
+ abjad_size,
77
+ self.char_embed_dim,
78
+ padding_idx=0,
79
+ )
80
+
81
+ self.attention = Attention(
82
+ kind="dot",
83
+ query_dim=self.word_lstm_units * 2,
84
+ input_dim=self.sent_lstm_units * 2,
85
+ )
86
+
87
+ self.word_embs = T.tensor(wembs).clone().to(dtype=T.float32)
88
+ self.word_embs = self.word_embs.to(self.device)
89
+
90
+ self.classifier = nn.Linear(self.attention.Dout + self.word_lstm_units * 2, self.num_classes)
91
+ self.dropout = nn.Dropout(self.final_dropout_p)
92
+
93
+ def forward(self, sents, words, labels=None, subword_lengths=None):
94
+ #^ sents : [b ts]
95
+ #^ words : [b ts tw]
96
+ #^ labels: [b ts tw]
97
+ max_words = min(self.max_sent_len, sents.shape[1])
98
+
99
+ word_mask = words.ne(0.).float()
100
+ #^ word_mask: [b ts tw]
101
+
102
+ if self.training:
103
+ q = 1.0 - self.sent_dropout_p
104
+ sdo = T.bernoulli(T.full(sents.shape, q))
105
+ sents_do = sents * sdo.long()
106
+ #^ sents_do : [b ts] ; DO(ts)
107
+ wembs = self.word_embs[sents_do]
108
+ #^ wembs : [b ts dw] ; DO(ts)
109
+ else:
110
+ wembs = self.word_embs[sents]
111
+ #^ wembs : [b ts dw]
112
+
113
+ sent_enc = self.sent_lstm(wembs.to(self.device))
114
+ #^ sent_enc : [b ts dwe]
115
+
116
+ sentword_do = sent_enc.unsqueeze(2)
117
+ #^ sentword_do : [b ts _ dwe]
118
+
119
+ sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
120
+ #^ sentword_do : [b ts tw dwe]
121
+
122
+ word_index = words.view(-1, self.max_word_len)
123
+ #^ word_index: [b*ts tw]?
124
+
125
+ cembs = self.char_embs(word_index)
126
+ #^ cembs : [b*ts tw dc]
127
+
128
+ sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
129
+ #^ sentword_do : [b*ts tw dwe]
130
+
131
+ char_embs = T.cat([cembs, sentword_do], dim=-1)
132
+ #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe
133
+
134
+ char_enc, _ = self.word_lstm(char_embs)
135
+ #^ char_enc: [b*ts tw dce]
136
+
137
+ char_enc_reshaped = char_enc.view(-1, max_words, self.max_word_len, self.word_lstm_units * 2)
138
+ # #^ char_enc: [b ts tw dce]
139
+
140
+ omit_self_mask = (1.0 - T.eye(max_words)).unsqueeze(0).to(self.device)
141
+ attn_enc, attn_map = self.attention(char_enc_reshaped, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
142
+ # # #^ attn_enc: [b ts tw dae]
143
+
144
+ attn_enc = attn_enc.reshape(-1, self.max_word_len, self.attention.Dout)
145
+ # #^ attn_enc: [b*ts tw dae]
146
+
147
+ final_vec = T.cat([attn_enc, char_enc], dim=-1)
148
+
149
+ diac_out = self.classifier(self.dropout(final_vec))
150
+ #^ diac_out: [b*ts tw 7]
151
+
152
+ diac_out = diac_out.view(-1, max_words, self.max_word_len, self.num_classes)
153
+ #^ diac_out: [b ts tw 7]
154
+
155
+ if not self.batch_first:
156
+ diac_out = diac_out.swapaxes(1, 0)
157
+
158
+ return diac_out
159
+
160
+
161
+ def step(self, xt, yt, mask=None):
162
+ xt[1] = xt[1].to(self.device)
163
+ xt[2] = xt[2].to(self.device)
164
+
165
+ yt = yt.to(self.device)
166
+ #^ yt: [b ts tw]
167
+
168
+ diac, _ = self(*xt)
169
+ loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))
170
+
171
+ return loss
172
+
173
+ def predict(self, dataloader):
174
+ training = self.training
175
+ self.eval()
176
+
177
+ preds = {'haraka': [], 'shadda': [], 'tanween': []}
178
+ print("> Predicting...")
179
+ for inputs, _ in tqdm(dataloader, total=len(dataloader)):
180
+ inputs[0] = inputs[0].to(self.device)
181
+ inputs[1] = inputs[1].to(self.device)
182
+ diac, _ = self(*inputs)
183
+
184
+ output = np.argmax(T.softmax(diac.detach(), dim=-1).cpu().numpy(), axis=-1)
185
+ #^ [b ts tw]
186
+
187
+ haraka, tanween, shadda = flat_2_3head(output)
188
+
189
+ preds['haraka'].extend(haraka)
190
+ preds['tanween'].extend(tanween)
191
+ preds['shadda'].extend(shadda)
192
+
193
+ self.train(training)
194
+ return (
195
+ np.array(preds['haraka']),
196
+ np.array(preds["tanween"]),
197
+ np.array(preds["shadda"]),
198
+ )
199
+
200
+ class DiacritizerD3(nn.Module):
201
+ def __init__(self, config, device='cuda'):
202
+ super(DiacritizerD3, self).__init__()
203
+ self.max_word_len = config["train"]["max-word-len"]
204
+ self.max_sent_len = config["train"]["max-sent-len"]
205
+ self.char_embed_dim = config["train"]["char-embed-dim"]
206
+
207
+ self.sent_dropout_p = config["train"]["sent-dropout"]
208
+ self.diac_dropout_p = config["train"]["diac-dropout"]
209
+ self.vertical_dropout = config['train']['vertical-dropout']
210
+ self.recurrent_dropout = config['train']['recurrent-dropout']
211
+ self.recurrent_dropout_mode = config['train'].get('recurrent-dropout-mode', 'gal_tied')
212
+ self.recurrent_activation = config['train'].get('recurrent-activation', 'sigmoid')
213
+
214
+ self.sent_lstm_units = config["train"]["sent-lstm-units"]
215
+ self.word_lstm_units = config["train"]["word-lstm-units"]
216
+ self.decoder_units = config["train"]["decoder-units"]
217
+
218
+ self.sent_lstm_layers = config["train"]["sent-lstm-layers"]
219
+ self.word_lstm_layers = config["train"]["word-lstm-layers"]
220
+
221
+ self.cell = config['train'].get('rnn-cell', 'lstm')
222
+ self.num_layers = config["train"].get("num-layers", 2)
223
+ self.RNN_Layer = K_LSTM
224
+
225
+ self.batch_first = config['train'].get('batch-first', True)
226
+
227
+ self.baseline = config["train"].get("baseline", False)
228
+ self.device = device
229
+
230
+ def build(self, wembs: T.Tensor, abjad_size: int):
231
+ self.closs = F.cross_entropy
232
+ self.bloss = F.binary_cross_entropy_with_logits
233
+
234
+ rnn_kargs = dict(
235
+ recurrent_dropout_mode=self.recurrent_dropout_mode,
236
+ recurrent_activation=self.recurrent_activation,
237
+ )
238
+
239
+ self.sent_lstm = self.RNN_Layer(
240
+ input_size=300,
241
+ hidden_size=self.sent_lstm_units,
242
+ num_layers=self.sent_lstm_layers,
243
+ bidirectional=True,
244
+ vertical_dropout=self.vertical_dropout,
245
+ recurrent_dropout=self.recurrent_dropout,
246
+ batch_first=self.batch_first,
247
+ **rnn_kargs,
248
+ )
249
+
250
+ self.word_lstm = self.RNN_Layer(
251
+ input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
252
+ hidden_size=self.word_lstm_units,
253
+ num_layers=self.word_lstm_layers,
254
+ bidirectional=True,
255
+ vertical_dropout=self.vertical_dropout,
256
+ recurrent_dropout=self.recurrent_dropout,
257
+ batch_first=self.batch_first,
258
+ return_states=True,
259
+ **rnn_kargs,
260
+ )
261
+
262
+ self.char_embs = nn.Embedding(
263
+ abjad_size,
264
+ self.char_embed_dim,
265
+ padding_idx=0,
266
+ )
267
+
268
+ self.attention = Attention(
269
+ kind="dot",
270
+ query_dim=self.word_lstm_units * 2,
271
+ input_dim=self.sent_lstm_units * 2,
272
+ )
273
+
274
+ self.lstm_decoder = self.RNN_Layer(
275
+ input_size=self.word_lstm_units * 2 + self.attention.Dout + 8,
276
+ hidden_size=self.word_lstm_units * 2,
277
+ num_layers=1,
278
+ bidirectional=False,
279
+ vertical_dropout=self.vertical_dropout,
280
+ recurrent_dropout=self.recurrent_dropout,
281
+ batch_first=self.batch_first,
282
+ return_states=True,
283
+ **rnn_kargs,
284
+ )
285
+
286
+ self.word_embs = T.tensor(wembs, dtype=T.float32)
287
+
288
+ self.classifier = nn.Linear(self.lstm_decoder.hidden_size, 15)
289
+ self.dropout = nn.Dropout(0.2)
290
+
291
+ def forward(self, sents, words, labels):
292
+ #^ sents : [b ts]
293
+ #^ words : [b ts tw]
294
+ #^ labels: [b ts tw]
295
+
296
+ word_mask = words.ne(0.).float()
297
+ #^ word_mask: [b ts tw]
298
+
299
+ if self.training:
300
+ q = 1.0 - self.sent_dropout_p
301
+ sdo = T.bernoulli(T.full(sents.shape, q))
302
+ sents_do = sents * sdo.long()
303
+ #^ sents_do : [b ts] ; DO(ts)
304
+ wembs = self.word_embs[sents_do]
305
+ #^ wembs : [b ts dw] ; DO(ts)
306
+ else:
307
+ wembs = self.word_embs[sents]
308
+ #^ wembs : [b ts dw]
309
+
310
+ sent_enc = self.sent_lstm(wembs.to(self.device))
311
+ #^ sent_enc : [b ts dwe]
312
+
313
+ sentword_do = sent_enc.unsqueeze(2)
314
+ #^ sentword_do : [b ts _ dwe]
315
+
316
+ sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
317
+ #^ sentword_do : [b ts tw dwe]
318
+
319
+ word_index = words.view(-1, self.max_word_len)
320
+ #^ word_index: [b*ts tw]?
321
+
322
+ cembs = self.char_embs(word_index)
323
+ #^ cembs : [b*ts tw dc]
324
+
325
+ sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
326
+ #^ sentword_do : [b*ts tw dwe]
327
+
328
+ char_embs = T.cat([cembs, sentword_do], dim=-1)
329
+ #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe
330
+
331
+ char_enc, _ = self.word_lstm(char_embs)
332
+ #^ char_enc: [b*ts tw dce]
333
+
334
+ char_enc_reshaped = char_enc.view(-1, self.max_sent_len, self.max_word_len, self.word_lstm_units * 2)
335
+ #^ char_enc: [b ts tw dce]
336
+
337
+ omit_self_mask = (1.0 - T.eye(self.max_sent_len)).unsqueeze(0).to(self.device)
338
+ attn_enc, attn_map = self.attention(char_enc_reshaped, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
339
+ #^ attn_enc: [b ts tw dae]
340
+
341
+ attn_enc = attn_enc.view(-1, self.max_sent_len*self.max_word_len, self.attention.Dout)
342
+ #^ attn_enc: [b*ts tw dae]
343
+
344
+ if self.training and self.diac_dropout_p > 0:
345
+ q = 1.0 - self.diac_dropout_p
346
+ ddo = T.bernoulli(T.full(labels.shape[:-1], q))
347
+ labels = labels * ddo.unsqueeze(-1).long().to(self.device)
348
+ #^ labels : [b ts tw] ; DO(ts)
349
+
350
+ labels = labels.view(-1, self.max_sent_len*self.max_word_len, 8).float()
351
+ #^ labels: [b*ts tw 8]
352
+
353
+ char_enc = char_enc.view(-1, self.max_sent_len*self.max_word_len, self.word_lstm_units * 2)
354
+
355
+ final_vec = T.cat([attn_enc, char_enc, labels], dim=-1)
356
+ #^ final_vec: [b ts*tw dae+8]
357
+
358
+ dec_out, _ = self.lstm_decoder(final_vec)
359
+ #^ dec_out: [b*ts tw du]
360
+
361
+ dec_out = dec_out.reshape(-1, self.max_word_len, self.lstm_decoder.hidden_size)
362
+
363
+ diac_out = self.classifier(self.dropout(dec_out))
364
+ #^ diac_out: [b*ts tw 7]
365
+
366
+ diac_out = diac_out.view(-1, self.max_sent_len, self.max_word_len, 15)
367
+ #^ diac_out: [b ts tw 7]
368
+
369
+ if not self.batch_first:
370
+ diac_out = diac_out.swapaxes(1, 0)
371
+
372
+ return diac_out, attn_map
373
+
374
+ def predict_sample(self, sents, words, labels):
375
+
376
+ word_mask = words.ne(0.).float()
377
+ #^ mask: [b ts tw 1]
378
+
379
+ if self.training:
380
+ q = 1.0 - self.sent_dropout_p
381
+ sdo = T.bernoulli(T.full(sents.shape, q))
382
+ sents_do = sents * sdo.long()
383
+ #^ sents_do : [b ts] ; DO(ts)
384
+ wembs = self.word_embs[sents_do]
385
+ #^ wembs : [b ts dw] ; DO(ts)
386
+ else:
387
+ wembs = self.word_embs[sents]
388
+ #^ wembs : [b ts dw]
389
+
390
+ sent_enc = self.sent_lstm(wembs.to(self.device))
391
+ #^ sent_enc : [b ts dwe]
392
+
393
+ sentword_do = sent_enc.unsqueeze(2)
394
+ #^ sentword_do : [b ts _ dwe]
395
+
396
+ sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
397
+ #^ sentword_do : [b ts tw dwe]
398
+
399
+ word_index = words.view(-1, self.max_word_len)
400
+ #^ word_index: [b*ts tw]?
401
+
402
+ cembs = self.char_embs(word_index)
403
+ #^ cembs : [b*ts tw dc]
404
+
405
+ sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
406
+ #^ sentword_do : [b*ts tw dwe]
407
+
408
+ char_embs = T.cat([cembs, sentword_do], dim=-1)
409
+ #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe
410
+
411
+ char_enc, _ = self.word_lstm(char_embs)
412
+ #^ char_enc: [b*ts tw dce]
413
+ #^ word_states: ([b*ts dce], [b*ts dce])
414
+
415
+ char_enc = char_enc.view(-1, self.max_sent_len, self.max_word_len, self.word_lstm_units*2)
416
+ #^ char_enc: [b ts tw dce]
417
+
418
+ omit_self_mask = (1.0 - T.eye(self.max_sent_len)).unsqueeze(0).to(self.device)
419
+ attn_enc, _ = self.attention(char_enc, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
420
+ #^ attn_enc: [b ts tw dae]
421
+
422
+ all_out = T.zeros(*char_enc.size()[:-1], 15).to(self.device)
423
+ #^ all_out: [b ts tw 7]
424
+
425
+ batch_sz = char_enc.size()[0]
426
+ #^ batch_sz: b
427
+
428
+ zeros = T.zeros(1, batch_sz, self.lstm_decoder.hidden_size).to(self.device)
429
+ #^ zeros: [1 b du]
430
+
431
+ bos_tag = T.tensor([0,0,0,0,0,1,0,0]).unsqueeze(0)
432
+ #^ bos_tag: [1 8]
433
+
434
+ prev_label = T.cat([bos_tag]*batch_sz).to(self.device).float()
435
+ # bos_vec = T.cat([bos_tag]*batch_sz).to(self.device).float()
436
+ #^ prev_label: [b 8]
437
+
438
+ for ts in range(self.max_sent_len):
439
+ dec_hx = (zeros, zeros)
440
+ #^ dec_hx: [1 b du]
441
+ for tw in range(self.max_word_len):
442
+ final_vec = T.cat([attn_enc[:,ts,tw,:], char_enc[:,ts,tw,:], prev_label], dim=-1).unsqueeze(1)
443
+ #^ final_vec: [b 1 dce+8]
444
+ dec_out, dec_hx = self.lstm_decoder(final_vec, dec_hx)
445
+ #^ dec_out: [b 1 du]
446
+ dec_out = dec_out.squeeze(0)
447
+ dec_out = dec_out.transpose(0,1)
448
+
449
+ logits_raw = self.classifier(self.dropout(dec_out))
450
+ #^ logits_raw: [b 1 15]
451
+
452
+ out_idx = T.max(T.softmax(logits_raw.squeeze(), dim=-1), dim=-1)[1]
453
+
454
+ haraka, tanween, shadda = flat2_3head(out_idx.detach().cpu().numpy())
455
+
456
+ haraka_onehot = T.eye(6)[haraka].float().to(self.device)
457
+ #^ haraka_onehot+bos_tag: [b 6]
458
+
459
+ tanween = T.tensor(tanween).float().unsqueeze(-1).to(self.device)
460
+ shadda = T.tensor(shadda).float().unsqueeze(-1).to(self.device)
461
+
462
+ prev_label = T.cat([haraka_onehot, tanween, shadda], dim=-1)
463
+
464
+ all_out[:,ts,tw,:] = logits_raw.squeeze()
465
+
466
+ if not self.batch_first:
467
+ all_out = all_out.swapaxes(1, 0)
468
+
469
+ return all_out
470
+
471
+ def step(self, xt, yt, mask=None):
472
+ xt[1] = xt[1].to(self.device)
473
+ xt[2] = xt[2].to(self.device)
474
+ #^ yt: [b ts tw]
475
+ yt = yt.to(self.device)
476
+
477
+ if self.training:
478
+ diac, _ = self(*xt)
479
+ else:
480
+ diac = self.predict_sample(*xt)
481
+ #^ diac[0] : [b ts tw 5]
482
+
483
+ loss = self.closs(diac.view(-1,15), yt.view(-1))
484
+ return loss
485
+
486
+ def predict(self, dataloader):
487
+ training = self.training
488
+ self.eval()
489
+
490
+ preds = {'haraka': [], 'shadda': [], 'tanween': []}
491
+ print("> Predicting...")
492
+ for inputs, _ in tqdm(dataloader, total=len(dataloader)):
493
+ inputs[1] = inputs[1].to(self.device)
494
+ inputs[2] = inputs[2].to(self.device)
495
+ diac = self.predict_sample(*inputs)
496
+ output = np.argmax(T.softmax(diac.detach(), dim=-1).cpu().numpy(), axis=-1)
497
+ #^ [b ts tw]
498
+
499
+ haraka, tanween, shadda = flat_2_3head(output)
500
+
501
+ preds['haraka'].extend(haraka)
502
+ preds['tanween'].extend(tanween)
503
+ preds['shadda'].extend(shadda)
504
+
505
+ self.train(training)
506
+ return (
507
+ np.array(preds['haraka']),
508
+ np.array(preds["tanween"]),
509
+ np.array(preds["shadda"]),
510
+ )
511
+
512
+ if __name__ == "__main__":
513
+
514
+ import yaml
515
+ config_path = "configs/dd/config_d2.yaml"
516
+ model_path = "models/tashkeela-d2.pt"
517
+ with open(config_path, 'r', encoding="utf-8") as file:
518
+ config = yaml.load(file, Loader=yaml.FullLoader)
519
+
520
+ data_utils = DatasetUtils(config)
521
+ vocab_size = len(data_utils.letter_list)
522
+ word_embeddings = data_utils.embeddings
523
+
524
+ model = DiacritizerD2(config, device='cpu')
525
+ model.build(word_embeddings, vocab_size)
526
+ model.load_state_dict(T.load(model_path, map_location=T.device('cpu'))["state_dict"])
model_partial.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple
2
+ import yaml
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+
6
+ import torch as T
7
+ from torch import nn
8
+ from torch import functional as F
9
+ from diac_utils import flat_2_3head
10
+
11
+ from model_dd import DiacritizerD2
12
+
13
+ class Readout(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_size: int,
17
+ out_size: int,
18
+ ):
19
+ super().__init__()
20
+ self.W1 = nn.Linear(in_size, in_size)
21
+ self.W2 = nn.Linear(in_size, out_size)
22
+
23
+ def forward(self, x: T.Tensor):
24
+ z = self.W1(x)
25
+ z = T.tanh(z)
26
+ z = self.W2(x)
27
+ return z
28
+
29
+ class WordDD_LSTM(nn.Module):
30
+ def __init__(
31
+ self,
32
+ feature_size: int,
33
+ num_classes: int = 13,
34
+ return_logits: bool = True,
35
+ ):
36
+ super().__init__()
37
+ self.feature_size = feature_size
38
+ self.num_classes = num_classes
39
+ self.return_logits = return_logits
40
+ self.cell = nn.LSTM(feature_size)
41
+ self.head = Readout(feature_size, num_classes)
42
+
43
+ def forward(self, x: T.Tensor):
44
+ #^ x: [b tc dc]
45
+ z = self.cell(x)
46
+ #^ z: [b tc @dc]
47
+ y = self.head(z)
48
+ #^ y: [b tc Classes]
49
+ yhat = y
50
+ if not self.return_logits:
51
+ yhat = F.softmax(yhat, dim=1)
52
+ #^ yhat: [b tc @Classes]
53
+ return yhat
54
+
55
+ class PartialDiacOutput(NamedTuple):
56
+ preds_hard: T.Tensor
57
+ preds_ctxt_logit: T.Tensor
58
+ preds_base_logit: T.Tensor
59
+
60
+
61
+ class PartialDD(nn.Module):
62
+ def __init__(
63
+ self,
64
+ config: dict,
65
+ # feature_size: int,
66
+ # confidence_threshold: float,
67
+ d2=False
68
+ ):
69
+ super().__init__()
70
+ self._built = False
71
+ self.no_diac_id = 0
72
+ self._dummy = nn.Parameter(T.ones(1, 1))
73
+
74
+ self.config = config
75
+ self.sentence_diac = DiacritizerD2(self.config)
76
+
77
+ self.eval()
78
+
79
+ @property
80
+ def device(self):
81
+ return self._dummy.device
82
+
83
+ @property
84
+ def tokenizer(self):
85
+ return self.sentence_diac.tokenizer
86
+
87
+ def load_state_dict(
88
+ self,
89
+ state_dict: dict
90
+ ):
91
+ self.sentence_diac.load_state_dict(state_dict)
92
+
93
+ def _slim_batch(
94
+ self,
95
+ toke_ids: T.Tensor,
96
+ char_ids: T.Tensor,
97
+ diac_ids: T.Tensor,
98
+ subword_lengths: T.Tensor,
99
+ ):
100
+ #^ toke_ids: [b tt]
101
+ #^ char_ids: [b tw tc]
102
+ #^ diac_ids: [b tw tc "13"]
103
+ #^ subword_lengths: [b tw]
104
+ token_nonpad_mask = toke_ids.ne(self.tokenizer.pad_token_id)
105
+ Ttoken = token_nonpad_mask.sum(1).max()
106
+ toke_ids = toke_ids[:, :Ttoken]
107
+
108
+ char_nonpad_mask = char_ids.ne(0)
109
+ Tword = char_nonpad_mask.any(2).sum(1).max()
110
+ Tchar = char_nonpad_mask.sum(2).max()
111
+ char_ids = char_ids[:, :Tword, :Tchar]
112
+ diac_ids = diac_ids[:, :Tword, :Tchar]
113
+ subword_lengths = subword_lengths[:, :Tword]
114
+
115
+ return toke_ids, char_ids, diac_ids, subword_lengths
116
+
117
+ def word_diac(
118
+ self,
119
+ toke_ids: T.Tensor,
120
+ char_ids: T.Tensor,
121
+ diac_ids: T.Tensor,
122
+ subword_lengths: T.Tensor,
123
+ *,
124
+ shape: tuple = None,
125
+ ):
126
+ if shape is None:
127
+ toke_ids, char_ids, diac_ids, subword_lengths = self._slim_batch(
128
+ toke_ids, char_ids, diac_ids, subword_lengths
129
+ )
130
+ else:
131
+ Nb, Tw, Tc = shape
132
+ toke_ids = toke_ids[:, :]
133
+ char_ids = char_ids[:, :Tw, :Tc]
134
+ diac_ids = diac_ids[:, :Tw, :Tc, :]
135
+ subword_lengths = subword_lengths[:, :Tw]
136
+ Nb, Tw, Tc = char_ids.shape
137
+ # Tw = min(Tw, word_ids.shape[1])
138
+ #^ word_ids: [b tt]
139
+ #^ char_ids: [b tw tc]
140
+ # wids_flat = word_ids[:, Tw].reshape(Nb * Tw, 1)
141
+ # cids_flat = char_ids[:, Tw].reshape(Nb * Tw, 1, Tc)
142
+ # z = self.sentence_diac(wids_flat, cids_flat)
143
+
144
+ sent_word_strides = subword_lengths.cumsum(1)
145
+ assert tuple(subword_lengths.shape) == (Nb, Tw), f"{subword_lengths.shape} != {(Nb, Tw)=}"
146
+ max_tokens_per_word: int = subword_lengths.max().int().item()
147
+ word_x = T.zeros(Nb, Tw, max_tokens_per_word).to(toke_ids)
148
+ for i_b in range(toke_ids.shape[0]):
149
+ sent_i = toke_ids[i_b]
150
+ start_iw = 0
151
+ for i_word, end_iw in enumerate(sent_word_strides[i_b]):
152
+ if end_iw == start_iw: break
153
+ word = sent_i[start_iw:end_iw]
154
+ word_x[i_b, i_word, 0 : end_iw - start_iw] = word
155
+ start_iw = end_iw
156
+ #^ word_x: [b tw tt]
157
+ word_x = word_x.reshape(Nb * Tw, max_tokens_per_word)
158
+ cids_flat = char_ids.reshape(Nb * Tw, 1, Tc)
159
+ word_lengths = subword_lengths.reshape(Nb * Tw, 1)
160
+
161
+ z = self.sentence_diac(
162
+ word_x,
163
+ cids_flat,
164
+ diac_ids.reshape(Nb*Tw, Tc, -1),
165
+ subword_lengths=word_lengths,
166
+ )
167
+ # Nc = z.shape[-1]
168
+ #^ z: [b*tw, 1, tc, "13"]
169
+ z = z.reshape(Nb, Tw, Tc, -1)
170
+ return z
171
+
172
+ def forward(
173
+ self,
174
+ word_ids: T.Tensor,
175
+ char_ids: T.Tensor,
176
+ _labels: T.Tensor,
177
+ # ground_truth: T.Tensor,
178
+ # padding_mask: T.BoolTensor,
179
+ *,
180
+ eval_only: str = None,
181
+ subword_lengths: T.Tensor,
182
+ return_extra: bool = False
183
+ ):
184
+ # assert self._built and not self.training
185
+ assert not self.training
186
+ #^ word_ids: [b tw]
187
+ #^ char_ids: [b tw tc]
188
+ #^ ground_truth: [b tw tc]
189
+
190
+ padding_mask = char_ids.eq(0)
191
+ #^ padding_mask: [b tw tc]
192
+
193
+ if True or eval_only != 'base':
194
+ y_ctxt = self.sentence_diac(
195
+ word_ids,
196
+ char_ids,
197
+ _labels,
198
+ subword_lengths=subword_lengths,
199
+ )
200
+ out_shape = y_ctxt.shape[:-1]
201
+ else:
202
+ out_shape = self.sentence_diac._slim_batch_size(
203
+ word_ids,
204
+ char_ids,
205
+ _labels,
206
+ subword_lengths,
207
+ )[1].shape
208
+ #^ y_ctxt: [b tw tc "13"]
209
+ if eval_only == 'ctxt':
210
+ return y_ctxt.argmax(-1)
211
+
212
+ y_base = self.word_diac(
213
+ word_ids,
214
+ char_ids,
215
+ _labels,
216
+ subword_lengths,
217
+ shape=out_shape
218
+ )
219
+ #^ y_base: [b tw tc "13"]
220
+ if eval_only == 'base':
221
+ return y_base.argmax(-1)
222
+
223
+ ypred_ctxt = y_ctxt.argmax(-1)
224
+ ypred_base = y_base.argmax(-1)
225
+ #^ ypred: [b tw tc _]
226
+
227
+ # Maybe for eval
228
+ # ypred_ctxt[~((ypred_base == ground_truth) & (~padding_mask))] = self.no_diac_id
229
+ # return ypred_ctxt
230
+ ypred_ctxt[(padding_mask) | (ypred_base == ypred_ctxt)] = self.no_diac_id
231
+ if not return_extra:
232
+ return ypred_ctxt
233
+ else:
234
+ return PartialDiacOutput(ypred_ctxt, y_ctxt, y_base)
235
+
236
+ def step(self, xt, yt, mask=None):
237
+ raise NotImplementedError
238
+ xt[1] = xt[1].to(self.device)
239
+ xt[2] = xt[2].to(self.device)
240
+
241
+ yt = yt.to(self.device)
242
+ #^ yt: [b ts tw]
243
+
244
+ diac, _ = self(*xt) # xt: (word_ids, char_ids, _labels)
245
+ loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))
246
+
247
+ return loss
248
+
249
+ def predict_partial(
250
+ self,
251
+ dataloader,
252
+ return_extra=False,
253
+ eval_only: str = None,
254
+ ):
255
+ training = self.training
256
+ self.eval()
257
+
258
+ preds = {
259
+ 'haraka': [],
260
+ 'shadda': [],
261
+ 'tanween': [],
262
+ 'diacs': [],
263
+ 'y_ctxt': [],
264
+ 'y_base': [],
265
+ }
266
+ print("> Predicting...")
267
+ # breakpoint()
268
+ for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
269
+ # if i_batch > 10:
270
+ # break
271
+ #^ inputs: [toke_ids, char_ids, diac_ids]
272
+ inputs[0] = inputs[0].to(self.device) #< toke_ids
273
+ inputs[1] = inputs[1].to(self.device) #< char_ids
274
+ # inputs[2] = inputs[2].to(self.device) #< diac_ids
275
+
276
+ if self._use_d2:
277
+ subword_lengths = T.ones_like(inputs[0])
278
+ subword_lengths[inputs[0] == 0] = 0
279
+
280
+ with T.no_grad():
281
+ output = self(
282
+ *inputs,
283
+ subword_lengths=subword_lengths,
284
+ return_extra=return_extra,
285
+ eval_only=eval_only,
286
+ )
287
+
288
+ # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
289
+ if return_extra:
290
+ assert isinstance(output, PartialDiacOutput)
291
+ marks = output.preds_hard
292
+ preds['diacs'].extend(list(marks.detach().cpu().numpy()))
293
+ preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
294
+ preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
295
+ else:
296
+ assert isinstance(output, T.Tensor)
297
+ marks = output
298
+ preds['diacs'].extend(list(marks.detach().cpu().numpy()))
299
+ #^ [b ts tw]
300
+
301
+ haraka, tanween, shadda = flat_2_3head(marks)
302
+
303
+ preds['haraka'].extend(haraka)
304
+ preds['tanween'].extend(tanween)
305
+ preds['shadda'].extend(shadda)
306
+
307
+ self.train(training)
308
+ return {
309
+ 'diacritics': (
310
+ #! FIXME! Due to batch slimming, output diacritics may need padding.
311
+ np.array(preds['haraka']),
312
+ np.array(preds["tanween"]),
313
+ np.array(preds["shadda"]),
314
+ ),
315
+ 'other': ( # Would be empty when !return_extra
316
+ preds['y_ctxt'],
317
+ preds['y_base'],
318
+ preds['diacs'],
319
+ )
320
+ }
321
+
322
+ def predict(self, dataloader):
323
+ training = self.training
324
+ self.eval()
325
+
326
+ preds = {'haraka': [], 'shadda': [], 'tanween': []}
327
+ print("> Predicting...")
328
+ for inputs, _ in tqdm(dataloader, total=len(dataloader)):
329
+ inputs[0] = inputs[0].to(self.device)
330
+ inputs[1] = inputs[1].to(self.device)
331
+ output = self(*inputs)
332
+
333
+ # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
334
+ marks = output
335
+ #^ [b ts tw]
336
+
337
+ haraka, tanween, shadda = flat_2_3head(marks)
338
+
339
+ preds['haraka'].extend(haraka)
340
+ preds['tanween'].extend(tanween)
341
+ preds['shadda'].extend(shadda)
342
+
343
+ self.train(training)
344
+ return (
345
+ np.array(preds['haraka']),
346
+ np.array(preds["tanween"]),
347
+ np.array(preds["shadda"]),
348
+ )
predict.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable, Union, Tuple
2
+ from collections import Counter
3
+
4
+ import argparse
5
+ import os
6
+
7
+ import yaml
8
+ from pyarabic.araby import tokenize, strip_tatweel
9
+ from tqdm import tqdm
10
+
11
+ import numpy as np
12
+ import torch as T
13
+ from torch.utils.data import DataLoader
14
+
15
+ from diac_utils import HARAKAT_MAP, shakkel_char, diac_ids_of_line
16
+ from model_partial import PartialDD
17
+ from data_utils import DatasetUtils
18
+ from dataloader import DataRetriever
19
+ from segment import segment
20
+
21
+ class Predictor:
22
+ def __init__(self, config, text):
23
+
24
+ self.data_utils = DatasetUtils(config)
25
+ vocab_size = len(self.data_utils.letter_list)
26
+ word_embeddings = self.data_utils.embeddings
27
+
28
+ stride = config["segment"]["stride"]
29
+ window = config["segment"]["window"]
30
+ min_window = config["segment"]["min-window"]
31
+
32
+ segments, mapping = segment([text], stride, window, min_window)
33
+
34
+ mapping_lines = []
35
+ for sent_idx, seg_idx, word_idx, char_idx in mapping:
36
+ mapping_lines += [f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"]
37
+
38
+ self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
39
+ self.original_lines = [text]
40
+ self.segments = segments
41
+
42
+ self.device = T.device(
43
+ config['predictor'].get('device', 'cuda:0')
44
+ if T.cuda.is_available() else 'cpu'
45
+ )
46
+
47
+ self.model = PartialDD(config, d2=True)
48
+ self.model.sentence_diac.build(word_embeddings, vocab_size)
49
+ state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
50
+ self.model.load_state_dict(state_dict)
51
+ self.model.to(self.device)
52
+ self.model.eval()
53
+
54
+ self.data_loader = DataLoader(
55
+ DataRetriever(self.data_utils, segments),
56
+ batch_size=config["predictor"].get("batch-size", 32),
57
+ shuffle=False,
58
+ num_workers=config['loader'].get('num-workers', 0),
59
+ )
60
+
61
+ class PredictTri(Predictor):
62
+ def __init__(self, config, text):
63
+ super().__init__(config, text)
64
+ self.diacritics = {
65
+ "FATHA": 1,
66
+ "KASRA": 2,
67
+ "DAMMA": 3,
68
+ "SUKUN": 4
69
+ }
70
+ self.votes: Union[Counter[int], Counter[bool]] = Counter()
71
+
72
+ def count_votes(
73
+ self,
74
+ things: Union[Iterable[int], Iterable[bool]]
75
+ ):
76
+ self.votes.clear()
77
+ self.votes.update(things)
78
+ return self.votes.most_common(1)[0][0]
79
+
80
+ def predict_majority_vote(self):
81
+ y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
82
+ diacritized_lines = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
83
+ return diacritized_lines
84
+
85
+ def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
86
+ assert isinstance(self.model, PartialDD)
87
+ if not os.path.exists("dataset/cache/y_gen_diac.npy") or overwrite_cache:
88
+ if not os.path.exists("dataset/cache"):
89
+ os.mkdir("dataset/cache")
90
+ # segment_outputs = self.model.predict_partial(self.data_loader, return_extra=True)
91
+ segment_outputs = self.model.predict_partial(self.data_loader, return_extra=False, eval_only='ctxt')
92
+ T.save(segment_outputs, "dataset/cache/cache.pt")
93
+ else:
94
+ segment_outputs = T.load("dataset/cache/cache.pt")
95
+
96
+ y_gen_diac, y_gen_tanween, y_gen_shadda = segment_outputs['diacritics']
97
+ diacritized_lines, extra_for_lines = self.coalesce_votes_by_majority(
98
+ y_gen_diac, y_gen_tanween, y_gen_shadda,
99
+ )
100
+ extra_out = {
101
+ 'line_data': {
102
+ **extra_for_lines,
103
+ },
104
+ 'segment_data': {
105
+ **segment_outputs,
106
+ # 'logits': segment_outputs['logits'],
107
+ }
108
+ }
109
+ return diacritized_lines, extra_out
110
+
111
+ def coalesce_votes_by_majority(
112
+ self,
113
+ y_gen_diac: np.ndarray,
114
+ y_gen_tanween: np.ndarray,
115
+ y_gen_shadda: np.ndarray,
116
+ ):
117
+ prepped_lines_og = [' '.join(tokenize(strip_tatweel(line))) for line in self.original_lines]
118
+ max_line_chars = max(len(line) for line in prepped_lines_og)
119
+ diacritics_pred = np.full((len(self.original_lines), max_line_chars), fill_value=-1, dtype=int)
120
+
121
+ count_processed_sents = 0
122
+ do_break = False
123
+ diacritized_lines = []
124
+ for sent_idx, line in enumerate(tqdm(prepped_lines_og)):
125
+ count_processed_sents = sent_idx + 1
126
+ line = line.strip()
127
+ diacritized_line = ""
128
+ for char_idx, char in enumerate(line):
129
+ diacritized_line += char
130
+ char_vote_diacritic = []
131
+ # ? This is the voting part
132
+ if sent_idx not in self.mapping:
133
+ continue
134
+
135
+ mapping_s_i = self.mapping[sent_idx]
136
+ for seg_idx in mapping_s_i:
137
+ if self.data_utils.debug and seg_idx >= 256:
138
+ do_break = True
139
+ break
140
+
141
+ mapping_g_i = mapping_s_i[seg_idx]
142
+ for t_idx in mapping_g_i:
143
+
144
+ mapping_t_i = mapping_g_i[t_idx]
145
+ if char_idx in mapping_t_i:
146
+ c_idx = mapping_t_i.index(char_idx)
147
+ output_idx = np.s_[seg_idx, t_idx, c_idx]
148
+ diac_h3 = (y_gen_diac[output_idx], y_gen_tanween[output_idx], y_gen_shadda[output_idx])
149
+ diac_char_i = HARAKAT_MAP.index(diac_h3)
150
+ if c_idx < 13 and diac_char_i != 0:
151
+ char_vote_diacritic.append(diac_char_i)
152
+
153
+ if do_break:
154
+ break
155
+ if len(char_vote_diacritic) > 0:
156
+ char_mv_diac = self.count_votes(char_vote_diacritic)
157
+ diacritized_line += shakkel_char(*HARAKAT_MAP[char_mv_diac])
158
+ diacritics_pred[sent_idx, char_idx] = char_mv_diac
159
+ else:
160
+ diacritics_pred[sent_idx, char_idx] = 0
161
+ if do_break:
162
+ break
163
+
164
+ diacritized_lines += [diacritized_line.strip()]
165
+
166
+ print(f'[INFO] Cutting stats from {len(diacritics_pred)} to {count_processed_sents}')
167
+ extra = {
168
+ 'diac_pred': diacritics_pred[:count_processed_sents],
169
+ }
170
+ return diacritized_lines, extra
segment.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import yaml
3
+ import os
4
+ import pickle as pkl
5
+
6
+ from tqdm import tqdm
7
+ from pyarabic.araby import tokenize, strip_tashkeel, strip_tatweel
8
+
9
+ def export(path, text):
10
+ with open(path, 'w', encoding="utf-8") as fout:
11
+ fout.write('\n'.join(text))
12
+
13
+ def segment(lines, stride, window_sz, min_window_sz):
14
+ segments, mapping = [], []
15
+ real_seg_idx = 0
16
+
17
+ for sent_idx, line in tqdm(enumerate(lines), total=len(lines)):
18
+ line: str = strip_tatweel(line)
19
+ line = line.strip()
20
+ tokens = tokenize(line)
21
+ if len(tokens) == 0: continue
22
+ if tokens[-1] == '\n': tokens = tokens[:-1]
23
+ seg_idx, idx = 0, 0
24
+ while idx < len(tokens):
25
+ window = tokens[idx:idx+window_sz]
26
+ if window_sz == -1: window = tokens
27
+ if len(window) < min_window_sz and seg_idx != 0: break
28
+
29
+ segment = ' '.join(window)
30
+ segments += [segment]
31
+ char_offset = len(strip_tashkeel(' '.join(tokens[:idx])))
32
+
33
+ if seg_idx > 0:
34
+ char_offset += 1
35
+
36
+ seg_tokens = tokenize(strip_tashkeel(segment))
37
+
38
+ j = 0
39
+ for st_idx, st in enumerate(seg_tokens):
40
+ for _ in range(len(st)):
41
+ mapping += [(sent_idx, real_seg_idx, st_idx, j+char_offset)]
42
+ j += 1
43
+ j += 1
44
+
45
+ real_seg_idx += 1
46
+ seg_idx += 1
47
+
48
+ if stride == -1: break
49
+
50
+ idx += (window_sz if stride >= window_sz else stride)
51
+
52
+ return segments, mapping
53
+
54
+ if __name__ == "__main__":
55
+ parser = argparse.ArgumentParser(description='Sentence Breaker')
56
+ parser.add_argument('-c', '--config', type=str,
57
+ default="config.yaml", help='Run Configs')
58
+ parser.add_argument('-d', '--data_dir', type=str,
59
+ default=None, help='Override for data path')
60
+ args = parser.parse_args()
61
+
62
+ with open(args.config, 'r', encoding="utf-8") as file:
63
+ config = yaml.load(file, Loader=yaml.FullLoader)
64
+
65
+ BASE_PATH = args.data_dir or config["paths"].get("base")
66
+
67
+ stride = config["segment"]["stride"]
68
+ window = config["segment"]["window"]
69
+ min_window = config["segment"]["min-window"]
70
+ export_map = config["segment"]["export-map"]
71
+
72
+ for fpath in tqdm(config["segment"]["files"]):
73
+ FILE_PATH = os.path.join(BASE_PATH, fpath)
74
+ SAVE_PATH = os.path.join(BASE_PATH, fpath[:-4] + f"-{stride}-{window}.txt")
75
+ MAP_PATH = os.path.join(BASE_PATH, fpath[:-4] + f"-{stride}-{window}.map")
76
+
77
+ with open(FILE_PATH, 'r', encoding="utf-8") as fin:
78
+ lines = fin.readlines()
79
+
80
+ segments, mapping = segment(lines, stride, window, min_window)
81
+
82
+ with open(SAVE_PATH, 'w', encoding="utf-8") as fout:
83
+ fout.write('\n'.join(segments))
84
+
85
+ if not export_map: continue
86
+
87
+ with open(MAP_PATH, 'w', encoding="utf-8") as fout:
88
+ for sent_idx, seg_idx, word_idx, char_idx in mapping:
89
+ fout.write(f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}\n")