-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathtb_logger.py
86 lines (77 loc) · 3.41 KB
/
tb_logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from utils.conf import base_path
import os
from argparse import Namespace
from typing import Dict, Any
import numpy as np
class TensorboardLogger:
    """Logs continual-learning training statistics to TensorBoard.

    One SummaryWriter is kept per evaluated setting. For a 'class-il' run a
    companion 'task-il' writer is also created, since class-il models are
    additionally evaluated under the task-il protocol.
    """

    def __init__(self, args: Namespace, setting: str,
                 stash: Dict[str, Any] = None) -> None:
        """
        Creates one SummaryWriter per logged setting.

        :param args: the arguments of the current run (``n_epochs`` is read)
        :param setting: the continual-learning setting of the run
        :param stash: checkpoint state; must provide 'model_name',
                      'task_idx' and 'epoch_idx'. Although the signature
                      defaults to None, a None stash raises at the first
                      subscript — callers must pass a real dict.
        """
        # Imported lazily so this module can be imported without tensorboard.
        from torch.utils.tensorboard import SummaryWriter
        self.settings = [setting]
        if setting == 'class-il':
            self.settings.append('task-il')
        self.loggers = {}
        self.name = stash['model_name']
        for a_setting in self.settings:
            # purge_step makes a restart from a checkpoint clean: events
            # previously logged at or after this global step are discarded.
            self.loggers[a_setting] = SummaryWriter(
                os.path.join(base_path(), 'tensorboard_runs', a_setting, self.name),
                purge_step=stash['task_idx'] * args.n_epochs + stash['epoch_idx'] + 1)
        # Record the full run configuration as a text summary. vars() gives
        # the Namespace attributes in a deterministic order (unlike __dir__).
        config_text = ', '.join(
            '%s=%s' % (name, value) for name, value in vars(args).items()
            if not name.startswith('_'))
        for a_logger in self.loggers.values():
            a_logger.add_text('config', config_text)

    def get_name(self) -> str:
        """
        :return: the name of the model
        """
        return self.name

    def log_accuracy(self, all_accs: np.ndarray, all_mean_accs: np.ndarray,
                     args: Namespace, task_number: int) -> None:
        """
        Logs the current accuracy value for each task.

        :param all_accs: the accuracies (row 0: class-il, row 1: task-il)
                         for each task
        :param all_mean_accs: the mean accuracies for (class-il, task-il)
        :param args: the arguments of the run
        :param task_number: the task index
        """
        mean_acc_common, mean_acc_task_il = all_mean_accs
        for setting, a_logger in self.loggers.items():
            # Row 0 of all_accs holds class-il accuracies, row 1 task-il.
            index = 1 if setting == 'task-il' else 0
            mean_acc = mean_acc_task_il if setting == 'task-il' \
                else mean_acc_common
            step = task_number * args.n_epochs
            for kk in range(len(all_accs[0])):
                a_logger.add_scalar('acc_task%02d' % (kk + 1),
                                    all_accs[index][kk], step)
            a_logger.add_scalar('acc_mean', mean_acc, step)

    def log_loss(self, loss: float, args: Namespace, epoch: int,
                 task_number: int, iteration: int) -> None:
        """
        Logs the loss value once per epoch.

        :param loss: the loss value
        :param args: the arguments of the run
        :param epoch: the epoch index
        :param task_number: the task index
        :param iteration: the current iteration (NOTE: currently unused —
                          the x-axis is task_number * n_epochs + epoch;
                          kept for interface compatibility)
        """
        step = task_number * args.n_epochs + epoch
        for a_logger in self.loggers.values():
            a_logger.add_scalar('loss', loss, step)

    def log_loss_gcl(self, loss: float, iteration: int) -> None:
        """
        Logs the loss value at each iteration (general continual learning).

        :param loss: the loss value
        :param iteration: the current iteration, used directly as the x-axis
        """
        for a_logger in self.loggers.values():
            a_logger.add_scalar('loss', loss, iteration)

    def close(self) -> None:
        """
        At the end of the execution, closes every underlying writer.
        """
        for a_logger in self.loggers.values():
            a_logger.close()