"""Entry point for training, validating, or testing an R2GenGPT model.

Arguments are parsed by ``configs.config.parser``. The run mode is chosen
from ``args.test`` / ``args.validate``; if neither is set, the model is
trained with ``Trainer.fit``.
"""
import os
from pprint import pprint

from configs.config import parser
from dataset.data_module import DataModule
from lightning_tools.callbacks import add_callbacks
from models.R2GenGPT import R2GenGPT
from lightning.pytorch import seed_everything
import lightning.pytorch as pl


def train(args):
    """Build the data module, trainer and model, then run the selected mode.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``devices``, ``num_nodes``, ``strategy``, ``accelerator``,
            ``precision``, ``val_check_interval``, ``limit_val_batches``,
            ``max_epochs``, ``num_sanity_val_steps``,
            ``accumulate_grad_batches``, ``ckpt_file``, ``test`` and
            ``validate``. The whole namespace is also handed to
            ``DataModule``, ``add_callbacks`` and (when no checkpoint is
            given) ``R2GenGPT``.
    """
    dm = DataModule(args)
    # add_callbacks returns a dict with "callbacks" and "loggers" entries,
    # which map directly onto the Trainer's callbacks/logger arguments.
    callbacks = add_callbacks(args)

    trainer = pl.Trainer(
        devices=args.devices,
        num_nodes=args.num_nodes,
        strategy=args.strategy,
        accelerator=args.accelerator,
        precision=args.precision,
        val_check_interval=args.val_check_interval,
        limit_val_batches=args.limit_val_batches,
        max_epochs=args.max_epochs,
        num_sanity_val_steps=args.num_sanity_val_steps,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=callbacks["callbacks"],
        logger=callbacks["loggers"],
    )

    if args.ckpt_file is not None:
        # strict=False: state-dict keys in the checkpoint need not match the
        # model exactly (missing/unexpected keys are tolerated).
        # NOTE(review): the current `args` are not passed here, so the model
        # is rebuilt from whatever hyperparameters the checkpoint stored —
        # confirm this is intended when CLI options differ from training.
        model = R2GenGPT.load_from_checkpoint(args.ckpt_file, strict=False)
    else:
        model = R2GenGPT(args)

    # `test` takes precedence over `validate`; with neither flag set, train.
    if args.test:
        trainer.test(model, datamodule=dm)
    elif args.validate:
        trainer.validate(model, datamodule=dm)
    else:
        trainer.fit(model, datamodule=dm)


def main():
    """Parse CLI arguments, prepare the output directory, seed RNGs, and run."""
    args = parser.parse_args()
    # Ensure the model output directory exists before anything writes to it.
    os.makedirs(args.savedmodel_path, exist_ok=True)
    pprint(vars(args))
    # Fixed global seed; workers=True also seeds dataloader worker processes.
    seed_everything(42, workers=True)
    train(args)


if __name__ == '__main__':
    main()