from collections.abc import Callable
import traceback
from typing import List, Optional, Tuple, Union
from datasets import Dataset
import re
import pickle
import os
from transformers.pipelines.pt_utils import KeyDataset
from transformers import AutoTokenizer
from tqdm.auto import tqdm
URL_REGEX = r"\b(https?://\S+)\b"
EMAIL_REGEX = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
TAG_REGEX = r"<[^>]+>"
HANDLE_REGEX = r"[^a-zA-Z](@\w+)"
class Translator:
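    """Translate Chinese text with a Hugging Face translation pipeline.

    Long inputs are split into sentence-sized chunks that fit the model's token
    budget, URLs/emails/HTML tags/@handles are protected behind placeholders,
    and the translated chunks are stitched back into the original layout during
    post-processing.
    """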
def __init__(
self,
pipe: Callable,
max_length: int = 500,
batch_size: int = 16,
        save_every_step: int = 100,
        text_key: str = "text",
        save_filename: Optional[str] = None,
        replace_chinese_puncts: bool = False,
        verbose: bool = False,
):
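        """
        Args:
            pipe: a transformers translation pipeline exposing `.tokenizer` and `.model`.
            max_length: token budget per chunk; falls back to the model config when None.
            batch_size: batch size passed to the pipeline.
            save_every_step: checkpoint partial translations every N items.
            text_key: dataset column that holds the source text.
            save_filename: pickle file used to checkpoint and resume translations.
            replace_chinese_puncts: normalize punctuation in the model output.
            verbose: print progress information.
        """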
self.pipe = pipe
self.max_length = max_length
self.batch_size = batch_size
self.save_every_step = save_every_step
self.save_filename = save_filename
self.text_key = text_key
self.replace_chinese_puncts = replace_chinese_puncts
self.verbose = verbose
        if max_length is None and hasattr(pipe.model.config, "max_length"):
self.max_length = pipe.model.config.max_length
def _is_chinese(self, text: str) -> bool:
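        """Return True if the text contains at least one CJK ideograph."""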
return (
re.search(
r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002ebef\U00030000-\U000323af\ufa0e\ufa0f\ufa11\ufa13\ufa14\ufa1f\ufa21\ufa23\ufa24\ufa27\ufa28\ufa29\u3006\u3007][\ufe00-\ufe0f\U000e0100-\U000e01ef]?",
text,
)
is not None
)
def _split_sentences(self, text: str) -> List[str]:
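        """Split text into chunks of at most `max_length` tokens, cutting at
        sentence-final punctuation; a chunk with no usable delimiter is kept whole."""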
tokens = self.pipe.tokenizer(text, add_special_tokens=False)
token_size = len(tokens.input_ids)
        if token_size <= self.max_length:
return [text]
delimiter = set()
delimiter.update("。!?;…!?;")
sent_list = []
sent = text
while token_size > self.max_length:
orig_sent_len = token_size
# find the index of delimiter near the max_length
for i in range(token_size - 2, 0, -1):
token = tokens.token_to_chars(0, i)
char = sent[token.start : token.end]
if char in delimiter:
split_char_index = token.end
next_sent = sent[split_char_index:]
if len(next_sent) == 1:
continue
sent_list = [next_sent] + sent_list
sent = sent[0:split_char_index]
break
tokens = self.pipe.tokenizer(sent, add_special_tokens=False)
token_size = len(tokens.input_ids)
# no delimiter found, leave the sentence as it is
if token_size == orig_sent_len:
sent_list = [sent] + sent_list
sent = ""
break
if len(sent) > 0:
sent_list = [sent] + sent_list
return sent_list
    def _preprocess(self, text: str) -> Tuple[List[str], str, List[int], List[str]]:
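        """Replace entities with placeholders, split the text into translatable
        Chinese chunks, and build a format template for reassembly.

        Returns (sentences, template, num_tokens, entities).
        """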
# extract entities
tags = re.findall(TAG_REGEX, text)
handles = re.findall(HANDLE_REGEX, text)
urls = re.findall(URL_REGEX, text)
emails = re.findall(EMAIL_REGEX, text)
entities = urls + emails + tags + handles
# TODO: escape entity placeholders
for i, entity in enumerate(entities):
text = text.replace(entity, "eeee[%d]" % i, 1)
lines = text.split("\n")
sentences = []
num_tokens = []
template = text.replace("{", "{{").replace("}", "}}")
chunk_index = 0
for line in lines:
sentence = line.strip()
if len(sentence) > 0 and self._is_chinese(sentence):
chunks = self._split_sentences(sentence)
for chunk in chunks:
sentences.append(chunk)
tokens = self.pipe.tokenizer(chunk, add_special_tokens=False)
num_tokens.append(len(tokens.input_ids))
chunk = chunk.replace("{", "{{").replace("}", "}}")
template = template.replace(chunk, "{%d}" % chunk_index, 1)
chunk_index += 1
return sentences, template, num_tokens, entities
def _postprocess(
self,
template: str,
src_sentences: List[str],
translations: List[str],
entities: List[str],
) -> str:
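        """Fill the template with translated chunks, restore entities and any
        alphanumeric spans the model altered, and normalize punctuation."""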
processed = []
        alphanumeric_regex = re.compile(
            r"([a-zA-Za-zA-Z0-9\d+'\",,(\()\)::;;“”。·\.\??\!!‘’$\[\]<>/]+)"
        )
def hash_text(text: List[str]) -> str:
text = "|".join(text)
puncts_map = str.maketrans(",;:()。?!“”‘’", ",;:().?!\"\"''")
text = text.translate(puncts_map)
return text.lower()
for i, p in enumerate(translations):
src_sentence = src_sentences[i]
if self.replace_chinese_puncts:
                p = re.sub(",", ",", p)  # replace all commas
                p = re.sub(";", ";", p)  # replace semi-colons
                p = re.sub(":", ":", p)  # replace colons
                p = re.sub(r"\(", "(", p)  # replace round brackets
                p = re.sub(r"\)", ")", p)  # replace round brackets
p = re.sub(r"([\d]),([\d])", r"\1,\2", p)
src_matches = re.findall(alphanumeric_regex, src_sentence)
tgt_matches = re.findall(alphanumeric_regex, p)
# length not match or no match
if (
len(src_matches) != len(tgt_matches)
or len(src_matches) == 0
or len(tgt_matches) == 0
):
processed.append(p)
continue
# normalize full-width to half-width and lower case
src_hashes = hash_text(src_matches)
translated_hashes = hash_text(tgt_matches)
if src_hashes != translated_hashes:
# fix unmatched
for j in range(len(src_matches)):
if src_matches[j] != tgt_matches[j]:
p = p.replace(tgt_matches[j], src_matches[j], 1)
processed.append(p)
output = template.format(*processed)
# replace entities
for i, entity in enumerate(entities):
output = output.replace("eeee[%d]" % i, entity, 1)
# TODO: unescape entity placeholders
# fix repeated punctuations
output = re.sub(r"([「」()『』《》。,:])\1+", r"\1", output)
# fix brackets
        output = output.replace("“", "「").replace("”", "」")
return output
def _save(self, translations):
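        """Pickle the partial translations to `save_filename`."""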
with open(self.save_filename, "wb") as f:
pickle.dump(translations, f)
def __call__(self, inputs: Union[List[str], Dataset]) -> List[str]:
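        """Translate a list of strings (or a Dataset with a `text` column) and
        return the translated strings in input order."""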
templates = []
sentences = []
num_tokens = []
sentence_indices = []
outputs = []
translations = []
entities_list = []
resume_from_file = None
if isinstance(inputs, Dataset):
ds = inputs
else:
if isinstance(inputs, str):
inputs = [inputs]
ds = Dataset.from_list([{"text": text} for text in inputs])
for i, text_input in tqdm(
enumerate(ds), total=len(ds), desc="Preprocessing", disable=not self.verbose
):
            chunks, template, chunk_num_tokens, entities = self._preprocess(
                text_input["text"]
            )
templates.append(template)
sentence_indices.append([])
entities_list.append(entities)
for j, chunk in enumerate(chunks):
sentences.append(chunk)
                sentence_indices[-1].append(len(sentences) - 1)
                num_tokens.append(chunk_num_tokens[j])
if self.save_filename:
resume_from_file = (
self.save_filename if os.path.isfile(self.save_filename) else None
)
        if resume_from_file is not None:
            with open(resume_from_file, "rb") as f:
                translations = pickle.load(f)
if self.verbose:
print("translated:", len(translations))
print("to translate:", len(sentences) - len(translations))
if resume_from_file != None:
print(
"Resuming from {}({} records)".format(
resume_from_file, len(translations)
)
)
ds = Dataset.from_list(
[{"text": text} for text in sentences[len(translations) :]]
)
        max_token_length = max(num_tokens) if num_tokens else 0
if self.verbose:
print("Max Length:", max_token_length)
total_records = len(ds)
if total_records > 0:
step = 0
with tqdm(
disable=not self.verbose, desc="Translating", total=total_records
) as pbar:
for out in self.pipe(
KeyDataset(ds, self.text_key),
batch_size=self.batch_size,
max_length=self.max_length,
):
translations.append(out[0])
# export generate result every n steps
if (
step != 0
                        and self.save_filename is not None
and step % self.save_every_step == 0
):
self._save(translations)
step += 1
pbar.update(1)
        if self.save_filename is not None and total_records > 0:
self._save(translations)
for i, template in tqdm(
enumerate(templates),
total=len(templates),
desc="Postprocessing",
disable=not self.verbose,
):
try:
src_sentences = [sentences[index] for index in sentence_indices[i]]
tgt_sentences = [
translations[index]["translation_text"]
for index in sentence_indices[i]
]
output = self._postprocess(
template, src_sentences, tgt_sentences, entities_list[i]
)
outputs.append(output)
except Exception as error:
print(error)
print(template)
traceback.print_exc()
# print(template, sentence_indices[i], len(translations))
return outputs
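# Minimal stubs for the self-test below: `Object` is a bare attribute container
# and `FakePipe` mimics a translation pipeline without loading a model.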
class Object(object):
pass
class FakePipe(object):
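    """A stand-in for a transformers pipeline that yields scripted fake
    translations, used to exercise `Translator` end to end."""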
def __init__(self, max_length: int = 500):
self.model = Object()
self.model.config = Object()
self.model.config.max_length = max_length
self.tokenizer = AutoTokenizer.from_pretrained(
"indiejoseph/bart-translation-zh-yue"
)
    def __call__(self, text: List[str], batch_size: int, max_length: int):
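        """Yield one `[{"translation_text": ...}]` result per input, echoing the
        sentence with entity placeholders applied plus a few scripted edits
        ("abc" -> "ABC", a digit substitution, a misspelled drug name)."""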
for i in range(len(text)):
sentence = text[i]
# extract entities
tags = re.findall(TAG_REGEX, sentence)
handles = re.findall(HANDLE_REGEX, sentence)
urls = re.findall(URL_REGEX, sentence)
emails = re.findall(EMAIL_REGEX, sentence)
entities = urls + emails + tags + handles
for i, entity in enumerate(entities):
sentence = sentence.replace(entity, "eeee[%d]" % i, 1)
if "123" in sentence:
yield [{"translation_text": sentence.replace("123", "123")}]
continue
if "abc" in sentence:
yield [{"translation_text": sentence.replace("abc", "ABC")}]
continue
if "Acetaminophen" in sentence:
yield [
{
"translation_text": sentence.replace(
"Acetaminophen", "ACEtaminidien"
)
}
]
continue
yield [{"translation_text": sentence}]
if __name__ == "__main__":
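    # Smoke tests: run Translator against FakePipe and assert that
    # post-processing restores the source text (modulo quote normalization).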
fake_pipe = FakePipe(60)
translator = Translator(fake_pipe, max_length=60, batch_size=2, verbose=True)
text1 = "对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:"
text2 = """对于编写聊天机器人的脚本,你可以采用不同的方法,包括使用基于规则的系统、自然语言处理(NLP)技术和机器学习模型。下面是一个简单的例子,展示如何使用基于规则的方法来构建一个简单的聊天机器人:
```
# 设置用于匹配输入的关键字,并定义相应的回答数据字典。
keywords = {'你好': '你好!很高兴见到你。',
'再见': '再见!有机会再聊。',
'你叫什么': '我是一个聊天机器人。',
'你是谁': '我是一个基于人工智能技术制作的聊天机器人。'}
# 定义用于处理用户输入的函数。
def chatbot(input_text):
# 遍历关键字数据字典,匹配用户的输入。
for key in keywords:
if key in input_text:
# 如果匹配到了关键字,返回相应的回答。
return keywords[key]
# 如果没有找到匹配的关键字,返回默认回答。
return "对不起,我不知道你在说什么。"
# 运行聊天机器人。
while True:
# 获取用户输入。
user_input = input('用户: ')
# 如果用户输入“再见”,退出程序。
if user_input == '再见':
break
# 处理用户输入,并打印回答。
print('机器人: ' + chatbot(user_input))
```
这是一个非常简单的例子。对于实用的聊天机器人,可能需要使用更复杂的 NLP 技术和机器学习模型,以更好地理解和回答用户的问题。"""
text3 = "布洛芬(Ibuprofen)同撲熱息痛(Acetaminophen)係兩種常見嘅非處方藥,用於緩解疼痛、發燒同關節痛。"
text4 = "123 “abc” def's http://www.google.com [email protected] @abc 網址:http://localhost/abc下載"
text5 = "新力公司董事長盛田昭夫、自民黨國會議員石原慎太郎等人撰寫嘅《日本可以說「不」》、《日本還要說「不」》、《日本堅決說「不」》三本書中話道:「無啦啦挑起戰爭嘅好戰日本人,製造南京大屠殺嘅殘暴嘅日本人,呢d就係人地對日本人嘅兩個誤解,都係‘敲打日本’嘅兩個根由,我地必須採取措施消除佢。」"
outputs = translator([text1, text2, text3, text4, text5])
# for i, line in enumerate(outputs[1].split("\n")):
# input_text = text2.split("\n")[i]
# if line != input_text:
# print(line, text2.split("\n")[i])
assert outputs[0] == text1
assert outputs[1] == text2.replace("“", "「").replace("”", "」")
assert outputs[2] == text3
assert outputs[3] == text4.replace("“", "「").replace("”", "」")
assert outputs[4] == text5
    # edge case: a long sentence with no usable split point should stay in one chunk
assert (
len(
translator._split_sentences(
"新力公司董事長盛田昭夫、自民黨國會議員石原慎太郎等人撰寫嘅《日本可以說「不」》、《日本還要說「不」》、《日本堅決說「不」》三本書中話道:「無啦啦挑起戰爭嘅好戰日本人,製造南京大屠殺嘅殘暴嘅日本人,呢d就係人地對日本人嘅兩個誤解,都係‘敲打日本’嘅兩個根由,我地必須採取措施消除佢。」"
)
)
== 1
)
translator = Translator(fake_pipe, max_length=5, batch_size=2, verbose=True)
assert (
len(
translator._split_sentences("====。====。====。====。====。====。====。====。====。")
)
== 9
)