|
import os
|
|
from pprint import pprint
|
|
from configs.config import parser
|
|
from dataset.data_module import DataModule
|
|
from lightning_tools.callbacks import add_callbacks
|
|
from models.R2GenGPT import R2GenGPT
|
|
from lightning.pytorch import seed_everything
|
|
import lightning.pytorch as pl
|
|
|
|
|
|
def train(args):
    """Assemble data, trainer and model from ``args`` and run a single stage.

    Exactly one Lightning stage is executed: ``args.test`` takes precedence
    over ``args.validate``; when neither flag is truthy the model is fitted.

    Args:
        args: Parsed command-line namespace. Besides the ``pl.Trainer``
            settings forwarded below, this reads ``ckpt_file``, ``test`` and
            ``validate``. It is also handed unchanged to ``DataModule``,
            ``add_callbacks`` and (when no checkpoint is given) ``R2GenGPT``.
    """
    datamodule = DataModule(args)
    cb = add_callbacks(args)  # mapping with "callbacks" and "loggers" entries

    trainer_kwargs = dict(
        devices=args.devices,
        num_nodes=args.num_nodes,
        strategy=args.strategy,
        accelerator=args.accelerator,
        precision=args.precision,
        val_check_interval=args.val_check_interval,
        limit_val_batches=args.limit_val_batches,
        max_epochs=args.max_epochs,
        num_sanity_val_steps=args.num_sanity_val_steps,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=cb["callbacks"],
        logger=cb["loggers"],
    )
    trainer = pl.Trainer(**trainer_kwargs)

    if args.ckpt_file is None:
        model = R2GenGPT(args)
    else:
        # `args` is not passed here, so the model is presumably rebuilt from
        # hyperparameters stored in the checkpoint -- TODO confirm.
        # strict=False lets state-dict key mismatches through without error.
        model = R2GenGPT.load_from_checkpoint(args.ckpt_file, strict=False)

    if args.test:
        stage = trainer.test
    elif args.validate:
        stage = trainer.validate
    else:
        stage = trainer.fit
    stage(model, datamodule=datamodule)
|
|
|
|
def main(argv=None, *, seed=42):
    """Command-line entry point: parse options, prepare output dir, seed, train.

    Args:
        argv: Optional sequence of argument strings forwarded to
            ``parser.parse_args``. ``None`` (the default) keeps the usual
            command-line behavior, so ``main()`` works exactly as before.
        seed: Global random seed handed to ``seed_everything``. Defaults to
            42, the value that used to be hard-coded here.
    """
    args = parser.parse_args(argv)

    # Create the output directory up front; exist_ok makes reruns safe.
    os.makedirs(args.savedmodel_path, exist_ok=True)

    # Echo the full resolved configuration so each run's settings are logged.
    pprint(vars(args))

    # workers=True additionally seeds DataLoader worker processes.
    seed_everything(seed, workers=True)

    train(args)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the training entry point only when executed as a script, not on import.
    main()