bkhmsi commited on
Commit
bb42b73
·
1 Parent(s): f81acf7

restructured space

Browse files
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Partial Tashkeel
3
- emoji: 🐠
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.1.2
8
  app_file: app.py
9
- pinned: false
10
  license: cc-by-sa-3.0
11
  ---
12
 
 
1
  ---
2
+ title: Partial Arabic Diacritization
3
+ emoji: 🖋️
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.1.2
8
  app_file: app.py
9
+ pinned: true
10
  license: cc-by-sa-3.0
11
  ---
12
 
app.py CHANGED
@@ -9,12 +9,12 @@ output_path = "tashkeela-d2.pt"
9
  gdrive_templ = "https://drive.google.com/file/d/{}/view?usp=sharing"
10
  if not os.path.exists(output_path):
11
  model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
12
- gdown.download(gdrive_templ.format(model_gdrive_id), output=output_path, quiet=False)
13
 
14
  output_path = "vocab.vec"
15
  if not os.path.exists(output_path):
16
  vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
17
- gdown.download(gdrive_templ.format(vocab_gdrive_id), output=output_path, quiet=False)
18
 
19
  with open("config.yaml", 'r', encoding="utf-8") as file:
20
  config = yaml.load(file, Loader=yaml.FullLoader)
@@ -22,41 +22,99 @@ with open("config.yaml", 'r', encoding="utf-8") as file:
22
  config["train"]["max-sent-len"] = config["predictor"]["window"]
23
  config["train"]["max-token-count"] = config["predictor"]["window"] * 3
24
 
25
- def diacritze(text, do_partial):
26
- predictor = PredictTri(config, text)
27
- diacritized_lines = predictor.predict_partial(do_partial=do_partial)
 
 
 
 
 
 
 
 
 
 
28
  return diacritized_lines
29
 
30
- with gr.Blocks() as demo:
31
  gr.Markdown(
32
  """
33
  # Partial Diacritization: A Context-Contrastive Inference Approach
34
- ## Authors: Muhammad ElNokrashy, Badr AlKhamissi
 
35
  """)
36
 
37
- with gr.Row():
38
- check_box = gr.Checkbox(label="Partial", info="Apply Partial Diacritics or Full Diacritics")
39
- threshold_txt = gr.Textbox("")
40
-
41
- input_txt = gr.Textbox(
42
- placeholder="اكتب هنا",
43
- lines=5,
44
- label="Input",
45
- type='text',
46
- rtl=True,
47
- text_align='right',
48
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- output_txt = gr.Textbox(
51
- lines=5,
52
- label="Output",
53
- type='text',
54
- rtl=True,
55
- text_align='right',
56
- )
57
 
58
- btn = gr.Button(value="Shakkel")
59
- btn.click(diacritze, inputs=[input_txt, check_box], outputs=[output_txt])
60
 
61
  if __name__ == "__main__":
62
  demo.queue().launch(
 
9
  gdrive_templ = "https://drive.google.com/file/d/{}/view?usp=sharing"
10
  if not os.path.exists(output_path):
11
  model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
12
+ gdown.download(gdrive_templ.format(model_gdrive_id), output=output_path, quiet=False, fuzzy=True)
13
 
14
  output_path = "vocab.vec"
15
  if not os.path.exists(output_path):
16
  vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
17
+ gdown.download(gdrive_templ.format(vocab_gdrive_id), output=output_path, quiet=False, fuzzy=True)
18
 
19
  with open("config.yaml", 'r', encoding="utf-8") as file:
20
  config = yaml.load(file, Loader=yaml.FullLoader)
 
22
  config["train"]["max-sent-len"] = config["predictor"]["window"]
23
  config["train"]["max-token-count"] = config["predictor"]["window"] * 3
24
 
25
+ predictor = PredictTri(config)
26
+
27
+ def diacritze_full(text):
28
+ do_hard_mask = None
29
+ threshold = None
30
+ predictor.create_dataloader(text, False, do_hard_mask, threshold)
31
+ diacritized_lines = predictor.predict_partial(do_partial=False, lines=text.split('\n'))
32
+ return diacritized_lines
33
+
34
+ def diacritze_partial(text, mask_mode, threshold):
35
+ do_partial = True
36
+ predictor.create_dataloader(text, do_partial, mask_mode=="Hard", threshold)
37
+ diacritized_lines = predictor.predict_partial(do_partial=do_partial, lines=text.split('\n'))
38
  return diacritized_lines
39
 
40
+ with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
41
  gr.Markdown(
42
  """
43
  # Partial Diacritization: A Context-Contrastive Inference Approach
44
+ ### Authors: Muhammad ElNokrashy, Badr AlKhamissi
45
+ ### Paper Link: TBD
46
  """)
47
 
48
+ with gr.Tab(label="Full Diacritization"):
49
+
50
+ full_input_txt = gr.Textbox(
51
+ placeholder="اكتب هنا",
52
+ lines=5,
53
+ label="Input",
54
+ type='text',
55
+ rtl=True,
56
+ text_align='right',
57
+ )
58
+
59
+ full_output_txt = gr.Textbox(
60
+ lines=5,
61
+ label="Output",
62
+ type='text',
63
+ rtl=True,
64
+ text_align='right',
65
+ show_copy_button=True,
66
+ )
67
+
68
+ full_btn = gr.Button(value="Shakkel")
69
+ full_btn.click(diacritze_full, inputs=[full_input_txt], outputs=[full_output_txt])
70
+
71
+ gr.Examples(
72
+ examples=[
73
+ "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"
74
+ ],
75
+ inputs=full_input_txt,
76
+ outputs=full_output_txt,
77
+ fn=diacritze_full,
78
+ cache_examples=True,
79
+ )
80
+
81
+ with gr.Tab(label="Partial Diacritization") as partial_settings:
82
+ with gr.Row():
83
+ masking_mode = gr.Radio(choices=["Hard", "Soft"], value="Hard", label="Masking Mode")
84
+ threshold_slider = gr.Slider(label="Soft Masking Threshold", minimum=0, maximum=1, value=0.1)
85
+
86
+ partial_input_txt = gr.Textbox(
87
+ placeholder="اكتب هنا",
88
+ lines=5,
89
+ label="Input",
90
+ type='text',
91
+ rtl=True,
92
+ text_align='right',
93
+ )
94
+
95
+ partial_output_txt = gr.Textbox(
96
+ lines=5,
97
+ label="Output",
98
+ type='text',
99
+ rtl=True,
100
+ text_align='right',
101
+ show_copy_button=True,
102
+ )
103
+
104
+ partial_btn = gr.Button(value="Shakkel")
105
+ partial_btn.click(diacritze_partial, inputs=[partial_input_txt, masking_mode, threshold_slider], outputs=[partial_output_txt])
106
+
107
+ gr.Examples(
108
+ examples=[
109
+ ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Hard", 0],
110
+ ],
111
+ inputs=[partial_input_txt, masking_mode, threshold_slider],
112
+ outputs=partial_output_txt,
113
+ fn=diacritze_partial,
114
+ cache_examples=True,
115
+ )
116
 
 
 
 
 
 
 
 
117
 
 
 
118
 
119
  if __name__ == "__main__":
120
  demo.queue().launch(
data_utils.py CHANGED
@@ -26,7 +26,7 @@ class DatasetUtils:
26
  self.max_sent_len = config["train"]["max-sent-len"]
27
  self.max_token_count = config["train"]["max-token-count"]
28
  self.pad_target_val = -100
29
- self.pad_char_id = du.LETTER_LIST.index('<pad>')
30
 
31
  self.markov_signal = config['train'].get('markov-signal', False)
32
  self.batch_first = config['train'].get('batch-first', True)
 
26
  self.max_sent_len = config["train"]["max-sent-len"]
27
  self.max_token_count = config["train"]["max-token-count"]
28
  self.pad_target_val = -100
29
+ self.pad_char_id = du.DIAC_PAD_IDX #LETTER_LIST.index('<pad>')
30
 
31
  self.markov_signal = config['train'].get('markov-signal', False)
32
  self.batch_first = config['train'].get('batch-first', True)
diac_utils.py CHANGED
@@ -37,6 +37,8 @@ HARAKAT_MAP = [
37
  (0,0,0), #< Padding == -1 (also for spaces)
38
  ]
39
 
 
 
40
  SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
41
  LETTER_LIST = SPECIAL_TOKENS + list("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىي")
42
  CLASSES_LIST = [' ', 'َ', 'ً', 'ُ', 'ٌ', 'ِ', 'ٍ', 'ْ', 'ّ', 'َّ', 'ًّ', 'ُّ', 'ٌّ', 'ِّ', 'ٍّ']
@@ -63,13 +65,13 @@ def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
63
  return returned_text
64
 
65
  def diac_ids_of_line(line: str):
66
- words = tokenize(line)
67
  diacs = []
 
68
  for word in words:
69
  word_chars = split_word_on_characters_with_diacritics(word)
70
- cx, cy, cy_3head = create_label_for_word(word_chars)
71
  diacs.extend(cy)
72
- diacs.append(-1)
73
  return np.array(diacs[:-1])
74
 
75
  def strip_unknown_tashkeel(word: str):
@@ -77,6 +79,23 @@ def strip_unknown_tashkeel(word: str):
77
  return word
78
  return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def split_word_on_characters_with_diacritics(word: str):
81
  '''
82
  TODO! Make faster without deque and looping
@@ -100,6 +119,18 @@ def split_word_on_characters_with_diacritics(word: str):
100
  return chars_w_diac
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def char_type(char: str):
104
  if char in LETTER_LIST:
105
  return LETTER_LIST.index(char)
@@ -220,4 +251,4 @@ def flat2_3head(diac_idx):
220
  tanween += [c_out[1]]
221
  shadda += [c_out[2]]
222
 
223
- return np.array(haraka), np.array(tanween), np.array(shadda)
 
37
  (0,0,0), #< Padding == -1 (also for spaces)
38
  ]
39
 
40
+ DIAC_PAD_IDX = -1
41
+
42
  SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
43
  LETTER_LIST = SPECIAL_TOKENS + list("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىي")
44
  CLASSES_LIST = [' ', 'َ', 'ً', 'ُ', 'ٌ', 'ِ', 'ٍ', 'ْ', 'ّ', 'َّ', 'ًّ', 'ُّ', 'ٌّ', 'ِّ', 'ٍّ']
 
65
  return returned_text
66
 
67
  def diac_ids_of_line(line: str):
 
68
  diacs = []
69
+ words = tokenize(line)
70
  for word in words:
71
  word_chars = split_word_on_characters_with_diacritics(word)
72
+ _cx, cy, _cy_3head = create_label_for_word(word_chars)
73
  diacs.extend(cy)
74
+ diacs.append(DIAC_PAD_IDX)
75
  return np.array(diacs[:-1])
76
 
77
  def strip_unknown_tashkeel(word: str):
 
79
  return word
80
  return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)
81
 
82
+ def create_gt_labels(lines):
83
+ gt_labels = []
84
+ for line in lines:
85
+ # gt_labels_line = []
86
+ # tokens = tokenize(line.strip())
87
+ # for w_idx, word in enumerate(tokens):
88
+ # split_word = self.split_word_on_characters_with_diacritics(word)
89
+ # _, cy_flat, _ = du.create_label_for_word(split_word)
90
+
91
+ # gt_labels_line.extend(cy_flat)
92
+ # if w_idx+1 < len(tokens):
93
+ # gt_labels_line += [0]
94
+
95
+ gt_labels_line = diac_ids_of_line(line)
96
+ gt_labels.append(gt_labels_line)
97
+ return gt_labels
98
+
99
  def split_word_on_characters_with_diacritics(word: str):
100
  '''
101
  TODO! Make faster without deque and looping
 
119
  return chars_w_diac
120
 
121
 
122
+ def load_lines(path: str, *, strip: bool):
123
+ with open(path, 'r', encoding="utf-8", newline='\n') as fin:
124
+ if strip:
125
+ original_lines = [strip_tashkeel(normalize_spaces(line)) for line in fin.readlines()]
126
+ else:
127
+ original_lines = [normalize_spaces(line) for line in fin.readlines()]
128
+ return original_lines
129
+
130
+ def normalize_spaces(line: str):
131
+ return ' '.join(tokenize(line.strip()))
132
+
133
+
134
  def char_type(char: str):
135
  if char in LETTER_LIST:
136
  return LETTER_LIST.index(char)
 
251
  tanween += [c_out[1]]
252
  shadda += [c_out[2]]
253
 
254
+ return np.array(haraka), np.array(tanween), np.array(shadda)
gradio_cached_examples/16/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Output,flag,username,timestamp
2
+ ولو حمَل من مجلسِ الخيارِ ، ولم يُمنعْ من الكلام,,,2024-01-11 01:33:39.114395
gradio_cached_examples/6/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Output,flag,username,timestamp
2
+ وَلَوْ حَمَلَ مِنْ مَجْلِسِ الْخِيَارِ ، وَلَمْ يُمْنَعْ مِنْ الْكَلَامِ,,,2024-01-11 01:30:56.446393
predict.py CHANGED
@@ -12,7 +12,7 @@ import numpy as np
12
  import torch as T
13
  from torch.utils.data import DataLoader
14
 
15
- from diac_utils import HARAKAT_MAP, shakkel_char, diac_ids_of_line
16
  from model_partial import PartialDD
17
  from model_dd import DiacritizerD2
18
  from data_utils import DatasetUtils
@@ -31,10 +31,21 @@ def apply_tashkeel(
31
  diacs: Union[np.ndarray, T.Tensor]
32
  ):
33
  line_w_diacs = ""
34
- diacs_h3 = DatasetUtils.flat2_3head(diacs)
35
- for ch, tashkeel in zip(line, zip(*diacs_h3)):
 
 
 
 
 
36
  line_w_diacs += ch
37
- line_w_diacs += DatasetUtils.shakkel_char(*tashkeel)
 
 
 
 
 
 
38
  return line_w_diacs
39
 
40
  def diac_text(data, model_output_base, model_output_ctxt, selection_mode='contrastive-hard', threshold=0.1):
@@ -80,29 +91,16 @@ def diac_text(data, model_output_base, model_output_ctxt, selection_mode='contra
80
  line = apply_tashkeel(line, line_diacs)
81
  output.append(line)
82
 
83
- return '\n'.join(output)
84
 
85
  class Predictor:
86
- def __init__(self, config, text):
87
 
88
  self.data_utils = DatasetUtils(config)
89
  vocab_size = len(self.data_utils.letter_list)
90
  word_embeddings = self.data_utils.embeddings
 
91
 
92
- stride = config["segment"]["stride"]
93
- window = config["segment"]["window"]
94
- min_window = config["segment"]["min-window"]
95
-
96
- segments, mapping = segment([text], stride, window, min_window)
97
-
98
- mapping_lines = []
99
- for sent_idx, seg_idx, word_idx, char_idx in mapping:
100
- mapping_lines += [f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"]
101
-
102
- self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
103
- self.original_lines = [text]
104
- self.segments = segments
105
-
106
  self.device = T.device(
107
  config['predictor'].get('device', 'cuda:0')
108
  if T.cuda.is_available() else 'cpu'
@@ -115,16 +113,39 @@ class Predictor:
115
  self.model.to(self.device)
116
  self.model.eval()
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  self.data_loader = DataLoader(
119
  DataRetriever(self.data_utils, segments),
120
- batch_size=config["predictor"].get("batch-size", 32),
121
  shuffle=False,
122
- num_workers=config['loader'].get('num-workers', 0),
123
  )
124
-
125
  class PredictTri(Predictor):
126
- def __init__(self, config, text):
127
- super().__init__(config, text)
128
  self.diacritics = {
129
  "FATHA": 1,
130
  "KASRA": 2,
@@ -146,11 +167,15 @@ class PredictTri(Predictor):
146
  diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
147
  return diacritized_lines
148
 
149
- def predict_partial(self, do_partial):
150
  outputs = self.model.predict_partial(self.data_loader, return_extra=True, eval_only='both', do_partial=do_partial)
151
- y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']
152
 
153
- diac_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
 
 
 
 
 
154
  return '\n'.join(diac_lines)
155
 
156
  def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
 
12
  import torch as T
13
  from torch.utils.data import DataLoader
14
 
15
+ from diac_utils import HARAKAT_MAP, shakkel_char, flat2_3head
16
  from model_partial import PartialDD
17
  from model_dd import DiacritizerD2
18
  from data_utils import DatasetUtils
 
31
  diacs: Union[np.ndarray, T.Tensor]
32
  ):
33
  line_w_diacs = ""
34
+ ts, tw = diacs.shape
35
+ diacs = diacs.flatten()
36
+ diacs_h3 = flat2_3head(diacs)
37
+ diacs_h3 = tuple(x.reshape(ts, tw) for x in diacs_h3)
38
+ diac_char_idx = 0
39
+ diac_word_idx = 0
40
+ for ch in line:
41
  line_w_diacs += ch
42
+ if ch == " ":
43
+ diac_char_idx = 0
44
+ diac_word_idx += 1
45
+ else:
46
+ tashkeel = (diacs_h3[0][diac_word_idx][diac_char_idx], diacs_h3[1][diac_word_idx][diac_char_idx], diacs_h3[2][diac_word_idx][diac_char_idx])
47
+ diac_char_idx += 1
48
+ line_w_diacs += shakkel_char(*tashkeel)
49
  return line_w_diacs
50
 
51
  def diac_text(data, model_output_base, model_output_ctxt, selection_mode='contrastive-hard', threshold=0.1):
 
91
  line = apply_tashkeel(line, line_diacs)
92
  output.append(line)
93
 
94
+ return output
95
 
96
  class Predictor:
97
+ def __init__(self, config):
98
 
99
  self.data_utils = DatasetUtils(config)
100
  vocab_size = len(self.data_utils.letter_list)
101
  word_embeddings = self.data_utils.embeddings
102
+ self.config = config
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  self.device = T.device(
105
  config['predictor'].get('device', 'cuda:0')
106
  if T.cuda.is_available() else 'cpu'
 
113
  self.model.to(self.device)
114
  self.model.eval()
115
 
116
+ def create_dataloader(self, text, do_partial, do_hard_mask, threshold):
117
+ self.threshold = threshold
118
+ self.do_hard_mask = do_hard_mask
119
+
120
+ stride = self.config["segment"]["stride"]
121
+ window = self.config["segment"]["window"]
122
+ min_window = self.config["segment"]["min-window"]
123
+ if self.do_hard_mask or not do_partial:
124
+ segments, mapping = segment([text], stride, window, min_window)
125
+
126
+ mapping_lines = []
127
+ for sent_idx, seg_idx, word_idx, char_idx in mapping:
128
+ mapping_lines += [f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"]
129
+
130
+ self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
131
+ self.original_lines = [text]
132
+ self.segments = segments
133
+ else:
134
+ segments = text.split('\n')
135
+
136
+ self.segments = segments
137
+ self.original_lines = text.split('\n')
138
+
139
  self.data_loader = DataLoader(
140
  DataRetriever(self.data_utils, segments),
141
+ batch_size=self.config["predictor"].get("batch-size", 32),
142
  shuffle=False,
143
+ num_workers=self.config['loader'].get('num-workers', 0),
144
  )
145
+
146
  class PredictTri(Predictor):
147
+ def __init__(self, config):
148
+ super().__init__(config)
149
  self.diacritics = {
150
  "FATHA": 1,
151
  "KASRA": 2,
 
167
  diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
168
  return diacritized_lines
169
 
170
+ def predict_partial(self, do_partial, lines):
171
  outputs = self.model.predict_partial(self.data_loader, return_extra=True, eval_only='both', do_partial=do_partial)
 
172
 
173
+ if self.do_hard_mask or not do_partial:
174
+ y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']
175
+ diac_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
176
+ else:
177
+ diac_lines = diac_text(lines, outputs["other"][1], outputs["other"][0], selection_mode='1', threshold=self.threshold)
178
+
179
  return '\n'.join(diac_lines)
180
 
181
  def predict_majority_vote_context_contrastive(self, overwrite_cache=False):