Sa2VA-simple-demo / vlm /engine /hooks /dataset_info_hook.py
fffiloni's picture
Migrated from GitHub
d59f323 verified
from mmengine.hooks import Hook
from xtuner.registry import BUILDER
class SpecialDatasetInfoHook(Hook):
def __init__(self, tokenizer, is_intern_repo_dataset=False, special_tokens=None):
self.tokenizer = BUILDER.build(tokenizer)
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.is_intern_repo_dataset = is_intern_repo_dataset
def log(self, runner, dataset, mode='train'):
def _log(input_ids, log_prefix=''):
if self.is_intern_repo_dataset:
input_ids = [abs(x) for x in input_ids]
text = self.tokenizer.decode(input_ids)
runner.logger.info(text)
runner.logger.info(f'Num {mode} samples {len(dataset)}')
runner.logger.info(f'{mode} example:')
if 'chosen_ids' in dataset[0]:
_log(dataset[0]['chosen_ids'], log_prefix='chosen: ')
_log(dataset[0]['rejected_ids'], log_prefix='rejected: ')
else:
_log(dataset[0]['input_ids'])
def before_train(self, runner) -> None:
do_train = runner.train_loop is not None
do_eval = runner.val_loop is not None
if do_train:
train_dataset = runner.train_dataloader.dataset
self.log(runner, train_dataset, mode='train')
if do_eval:
eval_dataset = runner.val_dataloader.dataset
self.log(runner, eval_dataset, mode='eval')
def before_val(self, runner) -> None:
eval_dataset = runner.val_dataloader.dataset
self.log(runner, eval_dataset, mode='eval')
def before_test(self, runner) -> None:
test_dataset = runner.test_dataloader.dataset
self.log(runner, test_dataset, mode='test')