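"""Experiment logger: extends TensorBoard's SummaryWriter to also back up the
config and code state, mirror log messages to stdout and a file, and save
best-model checkpoints."""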
import logging
import os
import time
from shutil import copytree, ignore_patterns
import torch
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter, summary


class LoggerWithTBoard(SummaryWriter):
def __init__(self, cfg):
# current time stamp and experiment log directory
self.start_time = time.strftime('%y-%m-%dT%H-%M-%S', time.localtime())
self.logdir = os.path.join(cfg.logdir, self.start_time)
# init tboard
super().__init__(self.logdir)
# backup the cfg
        OmegaConf.save(cfg, os.path.join(self.logdir, 'cfg.yaml'))
# backup the code state
if cfg.log_code_state:
dest_dir = os.path.join(self.logdir, 'code')
copytree(os.getcwd(), dest_dir, ignore=ignore_patterns(*cfg.patterns_to_ignore))
        # init a logger that prints to stdout and mirrors (mostly) the same messages to a log file
self.print_logger = logging.getLogger('main')
self.print_logger.setLevel(logging.INFO)
msgfmt = '[%(levelname)s] %(asctime)s - %(name)s \n %(message)s'
datefmt = '%d %b %Y %H:%M:%S'
formatter = logging.Formatter(msgfmt, datefmt)
        # stdout handler (the logger itself filters at INFO, so DEBUG here is effectively INFO)
sh = logging.StreamHandler()
sh.setLevel(logging.DEBUG)
sh.setFormatter(formatter)
self.print_logger.addHandler(sh)
# log file
        fh = logging.FileHandler(os.path.join(self.logdir, 'log.txt'))
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
self.print_logger.addHandler(fh)
self.print_logger.info(f'Saving logs and checkpoints @ {self.logdir}')

    def log_param_num(self, model):
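        # count only trainable parameters (those with requires_grad=True)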
param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
self.print_logger.info(f'The number of parameters: {param_num/1e+6:.3f} mil')
self.add_scalar('num_params', param_num, 0)
return param_num

    def log_iter_loss(self, loss, iteration, phase):
        self.add_scalar(f'{phase}/loss_iter', loss, iteration)

    def log_epoch_loss(self, loss, epoch, phase):
self.add_scalar(f'{phase}/loss', loss, epoch)
self.print_logger.info(f'{phase} ({epoch}): loss {loss:.3f};')

    def log_epoch_metrics(self, metrics_dict, epoch, phase):
for metric, val in metrics_dict.items():
self.add_scalar(f'{phase}/{metric}', val, epoch)
metrics_dict = {k: round(v, 4) for k, v in metrics_dict.items()}
self.print_logger.info(f'{phase} ({epoch}) metrics: {metrics_dict};')

    def log_test_metrics(self, metrics_dict, hparams_dict, best_epoch):
allowed_types = (int, float, str, bool, torch.Tensor)
hparams_dict = {k: v for k, v in hparams_dict.items() if isinstance(v, allowed_types)}
metrics_dict = {f'test/{k}': round(v, 4) for k, v in metrics_dict.items()}
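        # write the hparams summaries into this run's own event file so the run
        # appears in TensorBoard's HPARAMS tab; SummaryWriter.add_hparams is not
        # used here because it logs into a separate subrun directory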
exp, ssi, sei = summary.hparams(hparams_dict, metrics_dict)
self.file_writer.add_summary(exp)
self.file_writer.add_summary(ssi)
self.file_writer.add_summary(sei)
for k, v in metrics_dict.items():
self.add_scalar(k, v, best_epoch)
self.print_logger.info(f'test ({best_epoch}) metrics: {metrics_dict};')

    def log_best_model(self, model, loss, epoch, optimizer, metrics_dict):
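        # the checkpoint carries model and optimizer state plus metadata, so
        # training can be resumed or the best model re-evaluated later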
model_name = model.__class__.__name__
self.best_model_path = os.path.join(self.logdir, f'{model_name}-{self.start_time}.pt')
checkpoint = {
'loss': loss,
'metrics': metrics_dict,
'epoch': epoch,
'optimizer': optimizer.state_dict(),
'model': model.state_dict(),
}
torch.save(checkpoint, self.best_model_path)
self.print_logger.info(f'Saved model in {self.best_model_path}')
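

if __name__ == '__main__':
    # Minimal usage sketch. The config keys below mirror the attributes this
    # class reads (logdir, log_code_state, patterns_to_ignore); the concrete
    # values are illustrative assumptions, not taken from the original project.
    cfg = OmegaConf.create({
        'logdir': './logs',
        'log_code_state': False,  # set True to snapshot the working directory
        'patterns_to_ignore': ['logs', '.git', '__pycache__'],
    })
    logger = LoggerWithTBoard(cfg)
    logger.log_iter_loss(0.731, 10, 'train')
    logger.log_epoch_loss(0.654, 1, 'train')
    logger.close()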