import logging
import os
import time
from shutil import copytree, ignore_patterns

import torch
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter, summary


class LoggerWithTBoard(SummaryWriter):

    def __init__(self, cfg):
        # current time stamp and experiment log directory
        self.start_time = time.strftime('%y-%m-%dT%H-%M-%S', time.localtime())
        self.logdir = os.path.join(cfg.logdir, self.start_time)
        # init tboard; SummaryWriter also exposes the same path as self.log_dir
        super().__init__(self.logdir)
        # backup the cfg
        OmegaConf.save(cfg, os.path.join(self.log_dir, 'cfg.yaml'))
        # backup the code state
        if cfg.log_code_state:
            dest_dir = os.path.join(self.logdir, 'code')
            copytree(os.getcwd(), dest_dir, ignore=ignore_patterns(*cfg.patterns_to_ignore))
        # init logger which prints to stdout and logs mostly the same things to the log file
        self.print_logger = logging.getLogger('main')
        self.print_logger.setLevel(logging.INFO)
        msgfmt = '[%(levelname)s] %(asctime)s - %(name)s \n %(message)s'
        datefmt = '%d %b %Y %H:%M:%S'
        formatter = logging.Formatter(msgfmt, datefmt)
        # stdout handler
        sh = logging.StreamHandler()
        sh.setLevel(logging.DEBUG)
        sh.setFormatter(formatter)
        self.print_logger.addHandler(sh)
        # log file handler
        fh = logging.FileHandler(os.path.join(self.log_dir, 'log.txt'))
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        self.print_logger.addHandler(fh)

        self.print_logger.info(f'Saving logs and checkpoints @ {self.logdir}')

    def log_param_num(self, model):
        param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.print_logger.info(f'The number of parameters: {param_num/1e+6:.3f} mil')
        self.add_scalar('num_params', param_num, 0)
        return param_num

    def log_iter_loss(self, loss, iter, phase):
        self.add_scalar(f'{phase}/loss_iter', loss, iter)

    def log_epoch_loss(self, loss, epoch, phase):
        self.add_scalar(f'{phase}/loss', loss, epoch)
        self.print_logger.info(f'{phase} ({epoch}): loss {loss:.3f};')

    def log_epoch_metrics(self, metrics_dict, epoch, phase):
        for metric, val in metrics_dict.items():
            self.add_scalar(f'{phase}/{metric}', val, epoch)
        # rounded copy only for printing; full precision was already written to tboard above
        metrics_dict = {k: round(v, 4) for k, v in metrics_dict.items()}
        self.print_logger.info(f'{phase} ({epoch}) metrics: {metrics_dict};')

    def log_test_metrics(self, metrics_dict, hparams_dict, best_epoch):
        allowed_types = (int, float, str, bool, torch.Tensor)
        hparams_dict = {k: v for k, v in hparams_dict.items() if isinstance(v, allowed_types)}
        metrics_dict = {f'test/{k}': round(v, 4) for k, v in metrics_dict.items()}
        # write the hparam summaries directly so they stay in the same run
        # (SummaryWriter.add_hparams would create a separate sub-run)
        exp, ssi, sei = summary.hparams(hparams_dict, metrics_dict)
        self.file_writer.add_summary(exp)
        self.file_writer.add_summary(ssi)
        self.file_writer.add_summary(sei)
        for k, v in metrics_dict.items():
            self.add_scalar(k, v, best_epoch)
        self.print_logger.info(f'test ({best_epoch}) metrics: {metrics_dict};')

    def log_best_model(self, model, loss, epoch, optimizer, metrics_dict):
        model_name = model.__class__.__name__
        self.best_model_path = os.path.join(self.logdir, f'{model_name}-{self.start_time}.pt')
        checkpoint = {
            'loss': loss,
            'metrics': metrics_dict,
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'model': model.state_dict(),
        }
        torch.save(checkpoint, self.best_model_path)
        self.print_logger.info(f'Saved model in {self.best_model_path}')
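

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original logger): it assumes only the
# config fields read above (logdir, log_code_state, patterns_to_ignore) and a
# hypothetical toy model/optimizer, just to exercise the logging calls.
if __name__ == '__main__':
    cfg = OmegaConf.create({
        'logdir': './logs',
        'log_code_state': False,  # set True to also copy the working tree into the log dir
        'patterns_to_ignore': ['logs', '.git', '__pycache__'],
    })
    logger = LoggerWithTBoard(cfg)

    # hypothetical toy model and optimizer for illustration only
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    logger.log_param_num(model)
    logger.log_iter_loss(0.713, iter=0, phase='train')
    logger.log_epoch_loss(0.651, epoch=0, phase='train')
    logger.log_epoch_metrics({'accuracy': 0.55}, epoch=0, phase='valid')
    logger.log_best_model(model, 0.651, 0, optimizer, {'accuracy': 0.55})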