# Hugging Face Space app by bkhmsi (commit bb42b73, "restructured space"; 4.18 kB).
import os
import yaml
import gdown
import gradio as gr
from predict import PredictTri
from gradio import blocks
# Remote assets: the trained checkpoint and the vocabulary file are hosted on
# Google Drive and fetched into the working directory only when missing.
gdrive_templ = "https://drive.google.com/file/d/{}/view?usp=sharing"
model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"


def _download_if_missing(gdrive_id, output_path):
    """Download Google Drive file ``gdrive_id`` to ``output_path`` unless it already exists.

    Raises:
        RuntimeError: if gdown signals failure by returning None (some gdown
            versions print an error and return None instead of raising), so the
            app does not continue without its model assets.
    """
    if os.path.exists(output_path):
        return
    # fuzzy=True lets gdown extract the file id from the ".../view?usp=sharing" share URL.
    result = gdown.download(
        gdrive_templ.format(gdrive_id), output=output_path, quiet=False, fuzzy=True
    )
    if result is None:
        raise RuntimeError(
            f"Failed to download {output_path!r} from Google Drive (id={gdrive_id})"
        )


_download_if_missing(model_gdrive_id, "tashkeela-d2.pt")
_download_if_missing(vocab_gdrive_id, "vocab.vec")
# Read the run configuration, tie the training-time length limits to the
# predictor's window size, then build the shared predictor instance.
with open("config.yaml", 'r', encoding="utf-8") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

_window = config["predictor"]["window"]
_train_cfg = config["train"]  # same dict object as config["train"]; edits below land in config
_train_cfg["max-sent-len"] = _window
# NOTE(review): token budget is 3x the window; the reason for the factor isn't visible here.
_train_cfg["max-token-count"] = _window * 3

predictor = PredictTri(config)
def diacritze_full(text):
    """Handle the "Full Diacritization" tab: diacritize all of ``text``.

    Prepares the predictor's dataloader in non-partial mode with masking
    disabled (hard-mask flag and threshold both None), then returns whatever
    ``predictor.predict_partial`` yields for the input split on newlines.
    """
    # Full mode: no masking, so neither the hard-mask flag nor a threshold applies.
    predictor.create_dataloader(text, False, None, None)
    return predictor.predict_partial(do_partial=False, lines=text.split('\n'))
def diacritze_partial(text, mask_mode, threshold):
    """Handle the "Partial Diacritization" tab: partially diacritize ``text``.

    Args:
        text: Raw input; it is split on newlines into lines for prediction.
        mask_mode: Radio selection; only the exact string "Hard" turns on hard masking.
        threshold: Slider value, forwarded unchanged to the dataloader.

    Returns:
        Whatever ``predictor.predict_partial`` produces in partial mode.
    """
    hard_mask = mask_mode == "Hard"
    predictor.create_dataloader(text, True, hard_mask, threshold)
    return predictor.predict_partial(do_partial=True, lines=text.split('\n'))
# Arabic UI strings. In the scraped source these were mojibake (UTF-8 bytes
# decoded as a Thai/Latin code page), which showed gibberish to users and fed
# garbage into the cached examples.
_INPUT_PLACEHOLDER = "اكتب هنا"  # "write here"
_EXAMPLE_SENTENCE = "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"  # undiacritized sample input


def _rtl_textbox(label, **kwargs):
    """Build the 5-line, right-to-left, right-aligned Textbox used for every input/output field."""
    return gr.Textbox(
        lines=5,
        label=label,
        type='text',
        rtl=True,
        text_align='right',
        **kwargs,
    )


with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
    gr.Markdown(
        """
# Partial Diacritization: A Context-Contrastive Inference Approach
### Authors: Muhammad ElNokrashy, Badr AlKhamissi
### Paper Link: TBD
"""
    )

    with gr.Tab(label="Full Diacritization"):
        full_input_txt = _rtl_textbox("Input", placeholder=_INPUT_PLACEHOLDER)
        full_output_txt = _rtl_textbox("Output", show_copy_button=True)
        full_btn = gr.Button(value="Shakkel")
        full_btn.click(diacritze_full, inputs=[full_input_txt], outputs=[full_output_txt])
        gr.Examples(
            examples=[_EXAMPLE_SENTENCE],
            inputs=full_input_txt,
            outputs=full_output_txt,
            fn=diacritze_full,
            # Have Gradio precompute and cache example outputs so clicking one doesn't re-run the model.
            cache_examples=True,
        )

    with gr.Tab(label="Partial Diacritization") as partial_settings:
        # Masking controls sit side by side above the text fields.
        with gr.Row():
            masking_mode = gr.Radio(choices=["Hard", "Soft"], value="Hard", label="Masking Mode")
            # NOTE(review): diacritze_partial always forwards this value; whether it
            # matters in "Hard" mode depends on PredictTri — confirm.
            threshold_slider = gr.Slider(label="Soft Masking Threshold", minimum=0, maximum=1, value=0.1)
        partial_input_txt = _rtl_textbox("Input", placeholder=_INPUT_PLACEHOLDER)
        partial_output_txt = _rtl_textbox("Output", show_copy_button=True)
        partial_btn = gr.Button(value="Shakkel")
        partial_btn.click(
            diacritze_partial,
            inputs=[partial_input_txt, masking_mode, threshold_slider],
            outputs=[partial_output_txt],
        )
        gr.Examples(
            examples=[
                [_EXAMPLE_SENTENCE, "Hard", 0],
            ],
            inputs=[partial_input_txt, masking_mode, threshold_slider],
            outputs=partial_output_txt,
            fn=diacritze_partial,
            cache_examples=True,
        )
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve the app with default launch settings.
    # Optional launch() keyword overrides, kept for reference (not active):
    #   share=False, debug=False, server_port=7860, server_name="0.0.0.0",
    #   ssl_verify=False, ssl_certfile="cert.pem", ssl_keyfile="key.pem"
    queued_demo = demo.queue()
    queued_demo.launch()