import logging
import os
import time
from shutil import copytree, ignore_patterns

import torch
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter, summary


class LoggerWithTBoard(SummaryWriter):
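    """A TensorBoard SummaryWriter that also backs up the config and the code state,
    mirrors log messages to stdout and a log file, and saves best-model checkpoints."""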

    def __init__(self, cfg):
        # current time stamp and experiment log directory
        self.start_time = time.strftime('%y-%m-%dT%H-%M-%S', time.localtime())
        self.logdir = os.path.join(cfg.logdir, self.start_time)
        # init tboard
        super().__init__(self.logdir)
        # backup the cfg
        OmegaConf.save(cfg, os.path.join(self.logdir, 'cfg.yaml'))
        # backup the code state
        if cfg.log_code_state:
            dest_dir = os.path.join(self.logdir, 'code')
            copytree(os.getcwd(), dest_dir, ignore=ignore_patterns(*cfg.patterns_to_ignore))

        # init a logger that prints to stdout and writes (mostly) the same messages to a log file
        self.print_logger = logging.getLogger('main')
        self.print_logger.setLevel(logging.INFO)
        msgfmt = '[%(levelname)s] %(asctime)s - %(name)s \n    %(message)s'
        datefmt = '%d %b %Y %H:%M:%S'
        formatter = logging.Formatter(msgfmt, datefmt)
        # stdout
        sh = logging.StreamHandler()
        sh.setLevel(logging.DEBUG)
        sh.setFormatter(formatter)
        self.print_logger.addHandler(sh)
        # log file
        fh = logging.FileHandler(os.path.join(self.logdir, 'log.txt'))
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        self.print_logger.addHandler(fh)

        self.print_logger.info(f'Saving logs and checkpoints @ {self.logdir}')

    def log_param_num(self, model):
        param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.print_logger.info(f'The number of parameters: {param_num/1e+6:.3f} mil')
        self.add_scalar('num_params', param_num, 0)
        return param_num

    def log_iter_loss(self, loss, iteration, phase):
        self.add_scalar(f'{phase}/loss_iter', loss, iteration)

    def log_epoch_loss(self, loss, epoch, phase):
        self.add_scalar(f'{phase}/loss', loss, epoch)
        self.print_logger.info(f'{phase} ({epoch}): loss {loss:.3f};')

    def log_epoch_metrics(self, metrics_dict, epoch, phase):
        for metric, val in metrics_dict.items():
            self.add_scalar(f'{phase}/{metric}', val, epoch)
        metrics_dict = {k: round(v, 4) for k, v in metrics_dict.items()}
        self.print_logger.info(f'{phase} ({epoch}) metrics: {metrics_dict};')

    def log_test_metrics(self, metrics_dict, hparams_dict, best_epoch):
        allowed_types = (int, float, str, bool, torch.Tensor)
        hparams_dict = {k: v for k, v in hparams_dict.items() if isinstance(v, allowed_types)}
        metrics_dict = {f'test/{k}': round(v, 4) for k, v in metrics_dict.items()}
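        # write the hparams summaries into this run's own event file; presumably this keeps
        # the hparams and test metrics in the same TensorBoard run instead of the nested run
        # that SummaryWriter.add_hparams would create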
        exp, ssi, sei = summary.hparams(hparams_dict, metrics_dict)
        self.file_writer.add_summary(exp)
        self.file_writer.add_summary(ssi)
        self.file_writer.add_summary(sei)
        for k, v in metrics_dict.items():
            self.add_scalar(k, v, best_epoch)
        self.print_logger.info(f'test ({best_epoch}) metrics: {metrics_dict};')

    def log_best_model(self, model, loss, epoch, optimizer, metrics_dict):
        model_name = model.__class__.__name__
        self.best_model_path = os.path.join(self.logdir, f'{model_name}-{self.start_time}.pt')
        checkpoint = {
            'loss': loss,
            'metrics': metrics_dict,
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'model': model.state_dict(),
        }
        torch.save(checkpoint, self.best_model_path)
        self.print_logger.info(f'Saved model in {self.best_model_path}')
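

# A minimal usage sketch (not part of the original module): the cfg fields mirror the ones
# read in __init__ above; the model, optimizer, loss and metric values are placeholders.
if __name__ == '__main__':
    import torch.nn as nn

    cfg = OmegaConf.create({
        'logdir': './logs',
        'log_code_state': False,  # skip the code backup for this smoke test
        'patterns_to_ignore': ['logs', '__pycache__', '.git'],
    })
    logger = LoggerWithTBoard(cfg)

    model = nn.Linear(8, 2)  # stand-in model, just to exercise the helpers
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    logger.log_param_num(model)
    logger.log_iter_loss(0.7, 0, 'train')
    logger.log_epoch_loss(0.7, 0, 'train')
    logger.log_epoch_metrics({'accuracy': 0.5}, 0, 'valid')
    logger.log_best_model(model, 0.7, 0, optimizer, {'accuracy': 0.5})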