import collections
import dataclasses
import types

import pytorch_lightning as pl
import torch
import torch.utils.data
import transformers

from data import (
    generate_annotated_images,
    get_annotation_ground_truth_str,
    DataItem,
    get_extra_tokens,
    Batch,
    Split,
    BatchCollateFunction,
)
from utils import load_pickle_or_build_object_and_save


@dataclasses.dataclass
class Model:
    processor: transformers.DonutProcessor
    tokenizer: transformers.XLMRobertaTokenizerFast
    encoder_decoder: transformers.VisionEncoderDecoderModel
    batch_collate_function: BatchCollateFunction
    config: types.SimpleNamespace


def add_unknown_tokens_to_tokenizer(
    tokenizer, encoder_decoder, unknown_tokens: list[str]
):
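    """Register new tokens and resize the decoder embeddings to the new vocabulary."""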
    tokenizer.add_tokens(unknown_tokens)
    encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))


def find_unknown_tokens_for_tokenizer(tokenizer) -> collections.Counter:
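    """Count ground-truth tokens that the tokenizer maps to the unknown-token id."""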
    unknown_tokens_counter = collections.Counter()
    for annotated_image in generate_annotated_images():
        ground_truth = get_annotation_ground_truth_str(annotated_image.annotation)
        input_ids = tokenizer(ground_truth).input_ids
        tokens = tokenizer.tokenize(ground_truth, add_special_tokens=True)
        for token_id, token in zip(input_ids, tokens, strict=True):
            if token_id == tokenizer.unk_token_id:
                unknown_tokens_counter.update([token])
    return unknown_tokens_counter


def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
    tokenizer, token_ids
):
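    """Mask padding with -100, the ignore_index of PyTorch's cross-entropy loss."""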
    token_ids[token_ids == tokenizer.pad_token_id] = -100
    return token_ids


@dataclasses.dataclass
class BatchCollateFunction:
    processor: transformers.DonutProcessor
    tokenizer: transformers.XLMRobertaTokenizerFast
    decoder_sequence_max_length: int

    def __call__(self, batch: list[DataItem], split: Split) -> Batch:
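        """Collate DataItems into a Batch of pixel values, masked labels, and indices."""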
        images = [di.image for di in batch]
        images = self.processor(
            images, random_padding=split == Split.train, return_tensors="pt"
        ).pixel_values
        target_token_ids = self.tokenizer(
            [di.target_string for di in batch],
            add_special_tokens=False,
            max_length=self.decoder_sequence_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        labels = replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
            self.tokenizer, target_token_ids
        )
        data_indices = [di.data_index for di in batch]
        return Batch(images=images, labels=labels, data_indices=data_indices)
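
# Note: `__call__` takes a `split` argument, so a DataLoader presumably binds it
# first, e.g. with functools.partial (an assumption; the wiring lives outside
# this module):
#   collate_fn = functools.partial(model.batch_collate_function, split=Split.train)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_fn)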


def build_model(config: types.SimpleNamespace) -> Model:
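    """Assemble the Donut processor, tokenizer, and encoder-decoder from `config`."""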
    donut_processor = transformers.DonutProcessor.from_pretrained(
        config.pretrained_model_name
    )
    donut_processor.image_processor.size = dict(
        width=config.image_width, height=config.image_height
    )
    donut_processor.image_processor.do_align_long_axis = False
    tokenizer = donut_processor.tokenizer

    encoder_decoder_config = transformers.VisionEncoderDecoderConfig.from_pretrained(
        config.pretrained_model_name
    )
    encoder_decoder_config.encoder.image_size = (
        config.image_width,
        config.image_height,
    )
    encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(
        config.pretrained_model_name, config=encoder_decoder_config
    )

    # Add the special prompt tokens before looking up their ids: converting a
    # token that is not yet in the vocabulary would silently yield the <unk> id.
    extra_tokens = list(get_extra_tokens().__dict__.values())
    add_unknown_tokens_to_tokenizer(tokenizer, encoder_decoder, extra_tokens)

    encoder_decoder_config.pad_token_id = tokenizer.pad_token_id
    encoder_decoder_config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(
        get_extra_tokens().benetech_prompt
    )
    encoder_decoder_config.bos_token_id = encoder_decoder_config.decoder_start_token_id
    encoder_decoder_config.eos_token_id = tokenizer.convert_tokens_to_ids(
        get_extra_tokens().benetech_prompt_end
    )

    unknown_dataset_tokens = load_pickle_or_build_object_and_save(
        config.unknown_tokens_for_tokenizer_path,
        lambda: list(find_unknown_tokens_for_tokenizer(tokenizer).keys()),
    )
    add_unknown_tokens_to_tokenizer(tokenizer, encoder_decoder, unknown_dataset_tokens)
    tokenizer.eos_token_id = encoder_decoder_config.eos_token_id

    batch_collate_function = BatchCollateFunction(
        processor=donut_processor,
        tokenizer=tokenizer,
        decoder_sequence_max_length=config.decoder_sequence_max_length,
    )
    return Model(
        processor=donut_processor,
        tokenizer=tokenizer,
        encoder_decoder=encoder_decoder,
        batch_collate_function=batch_collate_function,
        config=config,
    )


def generate_token_strings(
    model: Model, images: torch.Tensor, skip_special_tokens: bool = True
) -> list[str]:
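    """Generate annotation strings for a batch of preprocessed images."""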
    decoder_output = model.encoder_decoder.generate(
        images,
        max_length=(
            10 if model.config.debug else model.config.decoder_sequence_max_length
        ),
        eos_token_id=model.tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )
    return model.tokenizer.batch_decode(
        decoder_output.sequences, skip_special_tokens=skip_special_tokens
    )


def predict_string(image, model: Model) -> str:
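    """Preprocess a single image and generate its predicted annotation string."""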
    image = model.processor(
        image, random_padding=False, return_tensors="pt"
    ).pixel_values
    string = generate_token_strings(model, image)[0]
    return string


class LightningModule(pl.LightningModule):
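    """Wraps the Donut encoder-decoder for training with PyTorch Lightning."""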

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.model = build_model(config)
        self.encoder_decoder = self.model.encoder_decoder

    def training_step(self, batch: Batch, batch_idx: int) -> torch.Tensor:
        loss = self.compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch: Batch, batch_idx: int, dataset_idx: int = 0):
        loss = self.compute_loss(batch)
        self.log("val_loss", loss)

    def compute_loss(self, batch: Batch) -> torch.Tensor:
        outputs = self.encoder_decoder(pixel_values=batch.images, labels=batch.labels)
        return outputs.loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams["config"].learning_rate
        )
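

# A minimal usage sketch, not part of the training pipeline. The config values
# and file paths below are assumptions chosen only to exercise the functions
# above; the checkpoint name is the public Donut base model, which may differ
# from the one this project actually uses.
if __name__ == "__main__":
    import PIL.Image

    config = types.SimpleNamespace(
        pretrained_model_name="naver-clova-ix/donut-base",  # assumed checkpoint
        image_width=1280,  # assumed resolution
        image_height=960,
        decoder_sequence_max_length=512,  # assumed decoder length
        unknown_tokens_for_tokenizer_path="unknown_tokens.pkl",  # hypothetical path
        learning_rate=3e-4,
        debug=False,
    )
    model = build_model(config)
    image = PIL.Image.open("example_chart.png").convert("RGB")  # hypothetical image
    print(predict_string(image, model))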