|
import numpy as np |
|
import torch as T |
|
|
|
from tqdm import tqdm |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from components.k_lstm import K_LSTM |
|
from components.attention import Attention |
|
from data_utils import DatasetUtils |
|
from diac_utils import flat2_3head, flat_2_3head |
|
|
|
class DiacritizerD2(nn.Module):
    """Two-level (word + character) recurrent diacritizer.

    Sub-modules assembled by :meth:`build`:

    * ``sent_lstm``: bidirectional ``K_LSTM`` over the word embeddings of a sentence.
    * ``word_lstm``: bidirectional ``K_LSTM`` over the characters of each word. Its input
      is the character embedding concatenated with the owning word's sentence encoding.
    * ``attention``: ``Attention(kind="dot")`` queried by the character states over the
      sentence encoding. It receives an off-diagonal ``prejudice_mask``, presumably so
      that a word does not attend to itself.
    * ``classifier``: linear layer over ``[attention output, character state]`` that
      produces ``num_classes`` logits per character.

    The sub-modules are created in :meth:`build`, not in ``__init__``. Call ``build``
    before running the model or loading a state dict.
    """

    def __init__(self, config, device=None):
        """Read hyper-parameters from ``config["train"]``.

        Args:
            config: Nested mapping. Every hyper-parameter is read from
                ``config["train"]``. Required keys raise ``KeyError`` when missing.
                Optional keys fall back to the defaults used below.
            device: Device used for the word-embedding table and the attention mask.
                ``None`` (the default) selects ``"cuda"`` when available, otherwise
                ``"cpu"``.
        """
        super().__init__()
        cfg = config["train"]

        # Input geometry.
        self.max_word_len = cfg["max-word-len"]
        self.max_sent_len = cfg["max-sent-len"]
        self.char_embed_dim = cfg["char-embed-dim"]

        # Regularisation.
        self.final_dropout_p = cfg["final-dropout"]
        self.sent_dropout_p = cfg["sent-dropout"]
        self.diac_dropout_p = cfg["diac-dropout"]  # read but not used by this class
        self.vertical_dropout = cfg['vertical-dropout']
        self.recurrent_dropout = cfg['recurrent-dropout']
        self.recurrent_dropout_mode = cfg.get('recurrent-dropout-mode', 'gal_tied')
        self.recurrent_activation = cfg.get('recurrent-activation', 'sigmoid')

        # Layer sizes.
        self.sent_lstm_units = cfg["sent-lstm-units"]
        self.word_lstm_units = cfg["word-lstm-units"]
        self.decoder_units = cfg["decoder-units"]  # read but not used by this class
        self.sent_lstm_layers = cfg["sent-lstm-layers"]
        self.word_lstm_layers = cfg["word-lstm-layers"]

        # Stored for config compatibility; build() always instantiates K_LSTM.
        self.cell = cfg.get('rnn-cell', 'lstm')
        self.num_layers = cfg.get("num-layers", 2)
        self.RNN_Layer = K_LSTM

        self.batch_first = cfg.get('batch-first', True)

        if device is None:
            device = 'cuda' if T.cuda.is_available() else 'cpu'
        self.device = device

        self.num_classes = 15

    def build(self, wembs: T.Tensor, abjad_size: int):
        """Create all sub-modules.

        Args:
            wembs: Word-embedding matrix (anything ``torch.as_tensor`` accepts), indexed
                by the word ids in ``sents``. Its last dimension becomes the
                ``sent_lstm`` input size. It is stored as a plain tensor on
                ``self.device``, not as a parameter or buffer. It is therefore not
                trained, not part of ``state_dict()``, and not moved by ``.to()``.
            abjad_size: Number of character ids. Index 0 is the padding index of the
                character embedding.
        """
        self.closs = F.cross_entropy
        self.bloss = F.binary_cross_entropy_with_logits  # not used by this class

        # Private float32 copy of the embedding table, placed on the model device.
        self.word_embs = T.as_tensor(wembs, dtype=T.float32).clone().to(self.device)

        rnn_kwargs = dict(
            recurrent_dropout_mode=self.recurrent_dropout_mode,
            recurrent_activation=self.recurrent_activation,
        )

        self.sent_lstm = self.RNN_Layer(
            input_size=self.word_embs.shape[-1],
            hidden_size=self.sent_lstm_units,
            num_layers=self.sent_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            **rnn_kwargs,
        )

        self.word_lstm = self.RNN_Layer(
            input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
            hidden_size=self.word_lstm_units,
            num_layers=self.word_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kwargs,
        )

        self.char_embs = nn.Embedding(
            abjad_size,
            self.char_embed_dim,
            padding_idx=0,
        )

        self.attention = Attention(
            kind="dot",
            query_dim=self.word_lstm_units * 2,
            input_dim=self.sent_lstm_units * 2,
        )

        self.classifier = nn.Linear(self.attention.Dout + self.word_lstm_units * 2, self.num_classes)
        self.dropout = nn.Dropout(self.final_dropout_p)

    def forward(self, sents, words, labels=None, subword_lengths=None):
        """Return per-character diacritic logits.

        Args:
            sents: Integer word ids indexing ``self.word_embs``. Dim 1 is the number of
                words (assumed ``(batch, n_words)``; TODO confirm).
            words: Integer character ids, where 0 means padding. Must be viewable as
                ``(-1, max_word_len)`` (assumed ``(batch, n_words, max_word_len)``).
            labels: Unused. Accepted so callers can pass ``(sents, words, labels)``.
            subword_lengths: Unused.

        Returns:
            Logits viewed as ``(-1, n_words, max_word_len, num_classes)``, where
            ``n_words = min(max_sent_len, sents.shape[1])``. Axes 0 and 1 are swapped
            when ``batch_first`` is False. Only the logits are returned; there is no
            attention map.
        """
        n_words = min(self.max_sent_len, sents.shape[1])
        word_mask = words.ne(0.).float()

        if self.training:
            # Word dropout: each word id is zeroed (mapped to embedding row 0) with
            # probability sent_dropout_p. The mask is drawn on sents' own device so the
            # multiplication cannot mix devices.
            keep = T.bernoulli(T.full(sents.shape, 1.0 - self.sent_dropout_p, device=sents.device))
            sents = sents * keep.long()
        wembs = self.word_embs[sents.to(self.word_embs.device)]

        sent_enc = self.sent_lstm(wembs.to(self.device))

        # Broadcast each word's sentence encoding over its characters, zero it on
        # padding characters, then apply dropout.
        sentword = self.dropout(sent_enc.unsqueeze(2) * word_mask.unsqueeze(-1))
        sentword = sentword.view(-1, self.max_word_len, self.sent_lstm_units * 2)

        cembs = self.char_embs(words.view(-1, self.max_word_len))
        char_enc, _ = self.word_lstm(T.cat([cembs, sentword], dim=-1))

        char_enc_4d = char_enc.view(-1, n_words, self.max_word_len, self.word_lstm_units * 2)

        # Ones everywhere except the diagonal: presumably stops a word attending to itself.
        omit_self_mask = (1.0 - T.eye(n_words)).unsqueeze(0).to(self.device)
        attn_enc, _ = self.attention(char_enc_4d, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        attn_enc = attn_enc.reshape(-1, self.max_word_len, self.attention.Dout)

        final_vec = T.cat([attn_enc, char_enc], dim=-1)
        diac_out = self.classifier(self.dropout(final_vec))
        diac_out = diac_out.view(-1, n_words, self.max_word_len, self.num_classes)

        if not self.batch_first:
            diac_out = diac_out.swapaxes(1, 0)

        return diac_out

    def step(self, xt, yt, mask=None):
        """Compute the cross-entropy loss for one batch.

        Args:
            xt: Sequence ``(sents, words, labels, ...)`` passed to ``forward``. Items 1
                and 2 are moved to ``self.device``. The caller's sequence is not
                modified, and tuples are accepted.
            yt: Integer class targets, flattened to match the flattened logits.
            mask: Unused.

        Returns:
            Scalar ``F.cross_entropy`` loss over all characters.
        """
        xt = list(xt)
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)

        # forward() returns the logits tensor itself, so it must not be tuple-unpacked.
        diac = self(*xt)

        # reshape (not view): the logits are non-contiguous when batch_first is False.
        return self.closs(diac.reshape(-1, self.num_classes), yt.reshape(-1))

    def predict(self, dataloader):
        """Run inference over ``dataloader`` and split predictions into three heads.

        The model runs in eval mode with gradients disabled. Its previous
        train/eval mode is restored afterwards, even if an error is raised.

        Args:
            dataloader: Iterable of ``(inputs, targets)`` pairs that supports ``len()``.
                Items 0 and 1 of ``inputs`` are moved to ``self.device``. ``targets``
                are ignored.

        Returns:
            ``(haraka, tanween, shadda)`` numpy arrays. Each is built by extending a list
            with the per-batch outputs of ``flat_2_3head``.
        """
        was_training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        try:
            with T.no_grad():
                for inputs, _ in tqdm(dataloader, total=len(dataloader)):
                    inputs = list(inputs)
                    inputs[0] = inputs[0].to(self.device)
                    inputs[1] = inputs[1].to(self.device)
                    diac = self(*inputs)

                    # argmax over the logits equals argmax over softmax(logits).
                    output = diac.argmax(dim=-1).cpu().numpy()
                    haraka, tanween, shadda = flat_2_3head(output)

                    preds['haraka'].extend(haraka)
                    preds['tanween'].extend(tanween)
                    preds['shadda'].extend(shadda)
        finally:
            self.train(was_training)

        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )
|
|
|
class DiacritizerD3(nn.Module):
    """Two-level recurrent diacritizer with an autoregressive character decoder.

    Encoder (built in :meth:`build`, shared by :meth:`forward` and
    :meth:`predict_sample` through :meth:`_encode`):

    * ``sent_lstm``: bidirectional ``K_LSTM`` over word embeddings.
    * ``word_lstm``: bidirectional ``K_LSTM`` over characters. Its input is the
      character embedding concatenated with the word's sentence encoding.
    * ``attention``: ``Attention(kind="dot")`` from character states over the sentence
      encoding, with an off-diagonal ``prejudice_mask``.

    Decoder: a unidirectional single-layer ``K_LSTM`` (``lstm_decoder``). Each step
    receives ``[attention output, character state, 8-dim label vector]``, and a linear
    ``classifier`` maps its output to ``num_classes`` logits. Training feeds the given
    labels (teacher forcing). Inference feeds back its own greedy predictions.
    """

    num_classes = 15  # size of the flat per-character label space

    def __init__(self, config, device=None):
        """Read hyper-parameters from ``config["train"]``.

        Args:
            config: Nested mapping. Every hyper-parameter is read from
                ``config["train"]``. Required keys raise ``KeyError`` when missing.
            device: Device for encoder inputs, masks and decoder state. ``None`` (the
                default) selects ``"cuda"`` when available, otherwise ``"cpu"``.
                Passing ``"cuda"`` explicitly keeps the old behaviour.
        """
        super().__init__()
        cfg = config["train"]

        # Input geometry.
        self.max_word_len = cfg["max-word-len"]
        self.max_sent_len = cfg["max-sent-len"]
        self.char_embed_dim = cfg["char-embed-dim"]

        # Regularisation.
        self.sent_dropout_p = cfg["sent-dropout"]
        self.diac_dropout_p = cfg["diac-dropout"]
        self.vertical_dropout = cfg['vertical-dropout']
        self.recurrent_dropout = cfg['recurrent-dropout']
        self.recurrent_dropout_mode = cfg.get('recurrent-dropout-mode', 'gal_tied')
        self.recurrent_activation = cfg.get('recurrent-activation', 'sigmoid')

        # Layer sizes.
        self.sent_lstm_units = cfg["sent-lstm-units"]
        self.word_lstm_units = cfg["word-lstm-units"]
        self.decoder_units = cfg["decoder-units"]  # read but not used by this class
        self.sent_lstm_layers = cfg["sent-lstm-layers"]
        self.word_lstm_layers = cfg["word-lstm-layers"]

        # Stored for config compatibility; build() always instantiates K_LSTM.
        self.cell = cfg.get('rnn-cell', 'lstm')
        self.num_layers = cfg.get("num-layers", 2)
        self.RNN_Layer = K_LSTM

        self.batch_first = cfg.get('batch-first', True)

        self.baseline = cfg.get("baseline", False)  # read but not used by this class

        if device is None:
            device = 'cuda' if T.cuda.is_available() else 'cpu'
        self.device = device

    def build(self, wembs: T.Tensor, abjad_size: int):
        """Create all sub-modules.

        Args:
            wembs: Word-embedding matrix (anything ``torch.as_tensor`` accepts). Its
                last dimension becomes the ``sent_lstm`` input size. It is stored as a
                plain float32 tensor, not a parameter or buffer, and is not moved to
                ``self.device``. Looked-up rows are moved in :meth:`_encode` instead.
            abjad_size: Number of character ids. Index 0 is the padding index.
        """
        self.closs = F.cross_entropy
        self.bloss = F.binary_cross_entropy_with_logits  # not used by this class

        self.word_embs = T.as_tensor(wembs, dtype=T.float32).clone()

        rnn_kwargs = dict(
            recurrent_dropout_mode=self.recurrent_dropout_mode,
            recurrent_activation=self.recurrent_activation,
        )

        self.sent_lstm = self.RNN_Layer(
            input_size=self.word_embs.shape[-1],
            hidden_size=self.sent_lstm_units,
            num_layers=self.sent_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            **rnn_kwargs,
        )

        self.word_lstm = self.RNN_Layer(
            input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
            hidden_size=self.word_lstm_units,
            num_layers=self.word_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kwargs,
        )

        self.char_embs = nn.Embedding(
            abjad_size,
            self.char_embed_dim,
            padding_idx=0,
        )

        self.attention = Attention(
            kind="dot",
            query_dim=self.word_lstm_units * 2,
            input_dim=self.sent_lstm_units * 2,
        )

        # Per-step input: attention output + character state + 8-dim label vector.
        self.lstm_decoder = self.RNN_Layer(
            input_size=self.word_lstm_units * 2 + self.attention.Dout + 8,
            hidden_size=self.word_lstm_units * 2,
            num_layers=1,
            bidirectional=False,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kwargs,
        )

        self.classifier = nn.Linear(self.lstm_decoder.hidden_size, self.num_classes)
        self.dropout = nn.Dropout(0.2)

    def _encode(self, sents, words):
        """Run the encoder stack shared by forward() and predict_sample().

        Returns:
            ``(char_enc, attn_enc, attn_map, n_words)``:

            * ``char_enc`` is viewed as ``(-1, n_words, max_word_len, 2 * word_lstm_units)``.
            * ``attn_enc`` and ``attn_map`` are whatever ``self.attention`` returns.
              ``attn_enc`` is indexed 4-D by callers.
            * ``n_words`` is ``min(max_sent_len, sents.shape[1])``.
        """
        n_words = min(self.max_sent_len, sents.shape[1])
        word_mask = words.ne(0.).float()

        if self.training:
            # Word dropout: zero each word id (embedding row 0) with prob. sent_dropout_p.
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q, device=sents.device))
            word_ids = sents * sdo.long()
        else:
            word_ids = sents
        # The index tensor must live on the embedding table's device.
        wembs = self.word_embs[word_ids.to(self.word_embs.device)]

        sent_enc = self.sent_lstm(wembs.to(self.device))

        # Broadcast each word's sentence encoding over its characters, mask padding, drop out.
        sentword = self.dropout(sent_enc.unsqueeze(2) * word_mask.unsqueeze(-1))
        sentword = sentword.view(-1, self.max_word_len, self.sent_lstm_units * 2)

        cembs = self.char_embs(words.view(-1, self.max_word_len))
        char_enc, _ = self.word_lstm(T.cat([cembs, sentword], dim=-1))
        char_enc = char_enc.view(-1, n_words, self.max_word_len, self.word_lstm_units * 2)

        # Ones everywhere except the diagonal: presumably stops a word attending to itself.
        omit_self_mask = (1.0 - T.eye(n_words)).unsqueeze(0).to(self.device)
        attn_enc, attn_map = self.attention(char_enc, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        return char_enc, attn_enc, attn_map, n_words

    def forward(self, sents, words, labels):
        """Teacher-forced forward pass.

        Args:
            sents: Integer word ids (assumed ``(batch, n_words)``; TODO confirm).
            words: Integer character ids, where 0 means padding. Must be viewable as
                ``(-1, max_word_len)``.
            labels: Decoder input label vectors with a trailing dimension of 8.
                Presumably they use the same layout that ``predict_sample`` feeds back
                (6-way haraka one-hot, tanween, shadda). Any shift relative to the
                targets is the caller's responsibility (NOTE(review): confirm).

        Returns:
            ``(logits, attn_map)``. The logits are viewed as
            ``(-1, n_words, max_word_len, num_classes)``, with axes 0 and 1 swapped
            when ``batch_first`` is False.
        """
        char_enc, attn_enc, attn_map, n_words = self._encode(sents, words)
        seq_len = n_words * self.max_word_len

        attn_enc = attn_enc.reshape(-1, seq_len, self.attention.Dout)

        if self.training and self.diac_dropout_p > 0:
            # Label dropout: zero a character's whole label vector with prob. diac_dropout_p.
            q = 1.0 - self.diac_dropout_p
            ddo = T.bernoulli(T.full(labels.shape[:-1], q))
            labels = labels * ddo.unsqueeze(-1).long().to(labels.device)

        labels = labels.reshape(-1, seq_len, 8).float()
        char_enc = char_enc.reshape(-1, seq_len, self.word_lstm_units * 2)
        final_vec = T.cat([attn_enc, char_enc, labels], dim=-1)

        # NOTE(review): here the decoder runs over the flattened n_words * max_word_len
        # sequence and carries state across words. predict_sample() resets the state
        # at every word. Confirm this train/inference difference is intended.
        dec_out, _ = self.lstm_decoder(final_vec)

        dec_out = dec_out.reshape(-1, self.max_word_len, self.lstm_decoder.hidden_size)
        diac_out = self.classifier(self.dropout(dec_out))
        diac_out = diac_out.view(-1, n_words, self.max_word_len, self.num_classes)

        if not self.batch_first:
            diac_out = diac_out.swapaxes(1, 0)

        return diac_out, attn_map

    def predict_sample(self, sents, words, labels):
        """Greedy autoregressive decoding.

        Each character's arg-max class is split by ``flat2_3head`` into
        (haraka, tanween, shadda) and fed back as the next step's 8-dim label vector.
        Decoding starts from the fixed vector ``[0, 0, 0, 0, 0, 1, 0, 0]``. The decoder
        state is reset to zeros at every word. The fed-back label is not reset; it
        carries over from the previous word's last character.

        Args:
            sents, words: As for :meth:`forward`.
            labels: Unused. Accepted so it can be called with forward's arguments.

        Returns:
            Logits tensor of shape ``(batch, n_words, max_word_len, num_classes)``, with
            axes 0 and 1 swapped when ``batch_first`` is False.
        """
        char_enc, attn_enc, _, n_words = self._encode(sents, words)

        batch_sz = char_enc.size(0)
        hidden = self.lstm_decoder.hidden_size
        all_out = T.zeros(*char_enc.size()[:-1], self.num_classes).to(self.device)

        zeros = T.zeros(1, batch_sz, hidden).to(self.device)
        bos_tag = T.tensor([0, 0, 0, 0, 0, 1, 0, 0]).unsqueeze(0)
        prev_label = bos_tag.repeat(batch_sz, 1).to(self.device).float()

        for ts in range(n_words):
            dec_hx = (zeros, zeros)
            for tw in range(self.max_word_len):
                step_in = T.cat([attn_enc[:, ts, tw, :], char_enc[:, ts, tw, :], prev_label], dim=-1).unsqueeze(1)
                dec_out, dec_hx = self.lstm_decoder(step_in, dec_hx)

                # Collapse the single time step to (batch, hidden). Using reshape rather
                # than squeeze keeps the batch axis intact when batch_sz == 1.
                dec_out = dec_out.reshape(batch_sz, hidden)
                logits_raw = self.classifier(self.dropout(dec_out))  # (batch, num_classes)

                out_idx = logits_raw.argmax(dim=-1)
                haraka, tanween, shadda = flat2_3head(out_idx.detach().cpu().numpy())

                haraka_onehot = T.eye(6)[haraka].float().to(self.device)
                tanween = T.tensor(tanween).float().unsqueeze(-1).to(self.device)
                shadda = T.tensor(shadda).float().unsqueeze(-1).to(self.device)
                prev_label = T.cat([haraka_onehot, tanween, shadda], dim=-1)

                all_out[:, ts, tw, :] = logits_raw

        if not self.batch_first:
            all_out = all_out.swapaxes(1, 0)

        return all_out

    def step(self, xt, yt, mask=None):
        """Compute the cross-entropy loss for one batch.

        Training mode uses teacher-forced :meth:`forward`. Eval mode uses greedy
        :meth:`predict_sample`.

        Args:
            xt: Sequence ``(sents, words, labels)``. Items 1 and 2 are moved to
                ``self.device``. The caller's sequence is not modified, and tuples are
                accepted.
            yt: Integer class targets.
            mask: Unused.

        Returns:
            Scalar ``F.cross_entropy`` loss.
        """
        xt = list(xt)
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        yt = yt.to(self.device)

        if self.training:
            diac, _ = self(*xt)
        else:
            diac = self.predict_sample(*xt)

        # reshape (not view): the logits are non-contiguous when batch_first is False.
        return self.closs(diac.reshape(-1, self.num_classes), yt.reshape(-1))

    def predict(self, dataloader):
        """Greedy inference over ``dataloader``, split into three heads.

        Runs in eval mode with gradients disabled. The previous train/eval mode is
        restored afterwards, even on error.

        Args:
            dataloader: Iterable of ``(inputs, targets)`` pairs that supports ``len()``.
                Items 1 and 2 of ``inputs`` are moved to ``self.device``. ``targets``
                are ignored.

        Returns:
            ``(haraka, tanween, shadda)`` numpy arrays accumulated from the per-batch
            outputs of ``flat_2_3head``.
        """
        was_training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        try:
            with T.no_grad():
                for inputs, _ in tqdm(dataloader, total=len(dataloader)):
                    inputs = list(inputs)
                    inputs[1] = inputs[1].to(self.device)
                    inputs[2] = inputs[2].to(self.device)
                    diac = self.predict_sample(*inputs)

                    # argmax over the logits equals argmax over softmax(logits).
                    output = diac.argmax(dim=-1).cpu().numpy()
                    haraka, tanween, shadda = flat_2_3head(output)

                    preds['haraka'].extend(haraka)
                    preds['tanween'].extend(tanween)
                    preds['shadda'].extend(shadda)
        finally:
            self.train(was_training)

        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )
|
|
|
if __name__ == "__main__":
    # Smoke-load: build a DiacritizerD2 on the CPU and restore its trained weights.
    import yaml

    config_path = "configs/dd/config_d2.yaml"
    model_path = "models/tashkeela-d2.pt"

    with open(config_path, 'r', encoding="utf-8") as file:
        config = yaml.load(file, Loader=yaml.FullLoader)

    data_utils = DatasetUtils(config)
    vocab_size = len(data_utils.letter_list)
    word_embeddings = data_utils.embeddings

    model = DiacritizerD2(config)
    # Force CPU placement. This must happen before build(), which copies the word
    # embeddings to model.device. Setting the attribute works whether or not
    # __init__ accepts a `device` argument.
    model.device = 'cpu'
    model.build(word_embeddings, vocab_size)

    # NOTE: T.load unpickles the checkpoint; only load model files from trusted sources.
    model.load_state_dict(T.load(model_path, map_location=T.device('cpu'))["state_dict"])
    model.eval()