-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathcontinual_training.py
76 lines (63 loc) · 2.47 KB
/
continual_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from datasets import get_gcl_dataset
from models import get_model
from utils.status import progress_bar
from utils.tb_logger import *
from utils.status import create_fake_stash
from models.utils.continual_model import ContinualModel
from argparse import Namespace
def evaluate(model: ContinualModel, dataset) -> float:
"""
Evaluates the final accuracy of the model.
:param model: the model to be evaluated
:param dataset: the GCL dataset at hand
:return: a float value that indicates the accuracy
"""
model.net.eval()
correct, total = 0, 0
while not dataset.test_over:
inputs, labels = dataset.get_test_data()
inputs, labels = inputs.to(model.device), labels.to(model.device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
correct += torch.sum(predicted == labels).item()
total += labels.shape[0]
acc = correct / total * 100
return acc
def train(args: Namespace):
    """
    Runs one general-continual-learning session end to end.

    The dataset and the model are built from ``args``; training batches are
    fed to ``model.observe`` until the dataset reports ``train_over``; the
    resulting model is then evaluated, the accuracy is printed and, when
    ``args.csv_log`` is set, it is also written through a ``CsvLogger``.

    :param args: the arguments of the current execution
    """
    if args.csv_log:
        # Imported on demand: only needed when CSV logging was requested.
        from utils.loggers import CsvLogger

    # Build the data stream and the model that will consume it.
    dataset = get_gcl_dataset(args)
    backbone = dataset.get_backbone()
    criterion = dataset.get_loss()
    model = get_model(args, backbone, criterion, dataset.get_transform())
    model.net.to(model.device)

    # The stash is created but not read again inside this function.
    model_stash = create_fake_stash(model, args)

    if args.csv_log:
        csv_logger = CsvLogger(dataset.SETTING, dataset.NAME, model.NAME)

    model.net.train()
    epoch = 0  # never incremented: the stream is consumed in a single pass
    step = 0
    while not dataset.train_over:
        batch_inputs, batch_labels, plain_inputs = dataset.get_train_data()
        batch_inputs = batch_inputs.to(model.device)
        batch_labels = batch_labels.to(model.device)
        plain_inputs = plain_inputs.to(model.device)

        step_loss = model.observe(batch_inputs, batch_labels, plain_inputs)

        expected_steps = dataset.LENGTH // args.batch_size
        progress_bar(step, expected_steps, epoch, 'C', step_loss)
        step += 1

    # Only the 'joint_gcl' model gets an explicit end-of-task call.
    if model.NAME == 'joint_gcl':
        model.end_task(dataset)

    acc = evaluate(model, dataset)
    print('Accuracy:', acc)

    if args.csv_log:
        csv_logger.log(acc)
        csv_logger.write(vars(args))