-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
227 lines (195 loc) · 7.92 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""Training Code."""
import time
import torch
import torch.nn as nn
from torch.nn.functional import softmax
from network import get_network
from data import get_dataloaders
from optim import get_optimizer
from evaluate import Metric, evaluate_step
from collections import defaultdict, Counter
from utils import AverageMeter, mapping_func
from ema import EMA
def _update_class_thresholds(cls_thresholds,
                             learning_status,
                             N,
                             num_classes,
                             threshold,
                             mapping):
    """Refresh per-class confidence thresholds in place from the learning status.

    ``learning_status[i]`` is ``-1`` while unlabeled sample ``i`` has no
    accepted pseudo label, otherwise the class last assigned to it. The count
    of samples per class is normalised into ``beta`` and passed through
    ``mapping`` to scale the base ``threshold``. While every sample is still
    ``-1`` the thresholds are left untouched.

    Args:
        cls_thresholds: 1-D tensor of length ``num_classes``, written in place.
        learning_status: per-sample status list of length ``N``.
        N: number of unlabeled samples.
        num_classes: number of classes.
        threshold: base confidence threshold.
        mapping: callable applied to each ``beta`` value.
    """
    counter = Counter(learning_status)
    num_unused = counter[-1]
    if num_unused == N:
        return
    class_counts = [counter[c] for c in range(num_classes)]
    max_counter = max(class_counts)
    if max_counter < num_unused:
        # normalize with eq.11 (warm-up while unused samples dominate)
        denominator = max(max_counter, N - sum(class_counts))
    else:
        denominator = max_counter
    # threshold per class
    for c, count in enumerate(class_counts):
        beta = count / denominator
        cls_thresholds[c] = mapping(beta) * threshold


def train_step(model,
               ema,
               criterion,
               optimizer,
               scheduler,
               num_classes,
               threshold,
               unsupervised_weight,
               amp_flag,
               scaler,
               X,
               U,
               N,
               learning_status,
               mapping,
               device):
    """Train one epoch.

    Iterates the labeled loader ``X`` and the unlabeled loader ``U`` in
    lockstep (stopping at the shorter one). Each step runs one forward pass
    over the weak labeled view plus the weak and strong unlabeled views. It
    combines a supervised loss with a masked unsupervised loss on the strong
    view, using hard pseudo labels taken from the weak view.

    Args:
        model: network being trained.
        ema: EMA helper; ``ema.update()`` is called after every step.
        criterion: loss returning per-sample values (it is multiplied
            element-wise by the pseudo-label mask before averaging).
        optimizer: optimizer stepped once per iteration.
        scheduler: LR scheduler stepped once per iteration.
        num_classes: number of classes.
        threshold: base confidence threshold; also the fixed cutoff used to
            record new entries in ``learning_status``.
        unsupervised_weight: weight of the unsupervised loss term.
        amp_flag: enable float16 autocast and use ``scaler`` for backward/step.
        scaler: gradient scaler used only when ``amp_flag`` is true.
        X: labeled loader yielding ``((weak, _), label, _)``.
        U: unlabeled loader yielding ``((weak, strong), _, index)``.
        N: number of unlabeled samples (length of ``learning_status``).
        learning_status: per-sample list, mutated in place with accepted
            pseudo labels (``-1`` means none yet).
        mapping: callable turning a normalised class status into a
            threshold scale factor.
        device: device the batches are moved to.

    Returns:
        ``(Acc, Ls, Lu, Mask)``:
            - ``Acc``: accuracy from ``Metric`` on the labeled predictions.
            - ``Ls``: mean supervised loss.
            - ``Lu``: mean unsupervised loss over steps where the mask was
              non-empty.
            - ``Mask``: mean fraction of unlabeled samples passing the mask.

    Side effects:
        Increments the module-level ``global_step`` once per iteration.
    """
    global global_step
    logs = defaultdict(AverageMeter)
    metric = Metric()
    # All zeros until some sample gets a learning status, so early on every
    # pseudo label with positive confidence passes the mask.
    cls_thresholds = torch.zeros(num_classes, device=device)
    model.train()
    for sample_x, sample_u in zip(X, U):
        with torch.autocast(device_type='cuda',
                            dtype=torch.float16,
                            enabled=amp_flag):
            # (weak, strong) augmented data
            (xw, _), y, _ = sample_x
            (uw, us), _, u_i = sample_u
            # one forward pass over all three views, then split back apart
            inputs = torch.cat([xw, uw, us], dim=0)
            outputs = model(inputs.to(device))
            xw_pred, uw_pred, us_pred = torch.split(
                outputs, [xw.shape[0], uw.shape[0], us.shape[0]])
            # supervised loss
            ls = criterion(xw_pred, y.to(device)).mean()
            # per-class thresholds from the status *before* this batch
            _update_class_thresholds(cls_thresholds, learning_status, N,
                                     num_classes, threshold, mapping)
            # update the pseudo label
            with torch.no_grad():
                uw_prob = softmax(uw_pred, dim=1)
                max_prob, hard_label = torch.max(uw_prob, dim=1)
                # the learning status is driven by the fixed base threshold
                over_threshold = max_prob > threshold
                if over_threshold.any():
                    u_i = u_i.to(device)
                    sample_index = u_i[over_threshold].tolist()
                    pseudo_label = hard_label[over_threshold].tolist()
                    for i, label in zip(sample_index, pseudo_label):
                        learning_status[i] = label
            # unsupervised loss
            batch_threshold = torch.index_select(cls_thresholds, 0, hard_label)
            indicator = max_prob > batch_threshold
            lu = (criterion(us_pred, hard_label) * indicator).mean()
            # Out-of-place add: `total_loss = ls; total_loss += ...` would
            # modify `ls` in place, so 'Ls' would log the total loss.
            total_loss = ls + lu * unsupervised_weight
        # optimization
        optimizer.zero_grad()
        if amp_flag:
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            optimizer.step()
        scheduler.step()
        ema.update()
        global_step += 1
        # logging
        metric.update_prediction(xw_pred, y)
        logs['Ls'].update(ls.item())
        logs['Mask'].update(torch.mean(indicator.float()).item())
        if indicator.any():
            logs['Lu'].update(lu.item())
    Acc = metric.calc_accuracy()
    Ls = logs['Ls'].avg
    Lu = logs['Lu'].avg
    Mask = logs['Mask'].avg
    return Acc, Ls, Lu, Mask
def train_network(args):
    """Train a network.

    Builds the model, EMA copy, optimizer/scheduler and data loaders from
    ``args``. When ``args.mode == 'resume'`` it restores them from the
    checkpoint at ``args.load_path``. It then calls ``train_step`` once per
    epoch. After each epoch it evaluates ``ema.shadow`` on the test loader,
    prints a progress line, and writes checkpoints under ``args.save_path``:
    ``ckpt.pth`` every epoch, plus ``ckpt_<step>.pth`` every 10th epoch.

    Args:
        args: option namespace. Fields read here: wandb, amp, network,
            num_classes, mode, load_path, ema_decay, mapping, lr, momentum,
            nesterov, weight_decay, iterations, data, num_X,
            include_x_in_u, augs, batch_size, mu, threshold, lu_weight,
            save_path (joined with ``/``, so presumably a ``pathlib.Path``).

    Side effects:
        Sets the module-level ``global_step``.
    """
    if args.wandb:
        import wandb
    global global_step
    device = torch.device('cuda')
    # Always bind `scaler`: it is passed to train_step unconditionally, and
    # leaving it undefined when AMP is off raised NameError at the call site.
    scaler = torch.cuda.amp.GradScaler() if args.amp else None
    # model
    model = get_network(args.network, args.num_classes)
    if args.mode == 'resume':
        # NOTE(review): the checkpoint pickles `args`; torch versions whose
        # torch.load defaults to weights_only=True may refuse it — confirm.
        ckpt = torch.load(args.load_path, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'])
        # same device as the fresh-start branch so the shadow matches `model`
        ema = EMA(model=model, decay=args.ema_decay, device=device)
        ema.shadow.load_state_dict(ckpt['ema'])
        start_iter = ckpt['iteration']
    else:
        start_iter = 0
        ema = EMA(model=model, decay=args.ema_decay, device=device)
    # Continue counting from the checkpoint. Otherwise saved 'iteration'
    # values, snapshot names and logging steps restart at 0 after a resume.
    global_step = start_iter
    model.to(device)
    # mapping function of beta
    mapping = mapping_func(args.mapping)
    # criterion: per-sample losses, train_step masks them element-wise
    criterion = nn.CrossEntropyLoss(reduction='none')
    # optimizer
    optimizer, scheduler = get_optimizer(model=model,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         nesterov=args.nesterov,
                                         weight_decay=args.weight_decay,
                                         iterations=args.iterations)
    if args.mode == 'resume':
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
    # labeled, unlabeled and test data
    X, U, T = get_dataloaders(data=args.data,
                              num_X=args.num_X,
                              include_x_in_u=args.include_x_in_u,
                              augs=args.augs,
                              batch_size=args.batch_size,
                              mu=args.mu)
    # Number of unlabeled data
    N = len(U.dataset.indices)
    if args.mode == 'resume' and 'learning_status' in ckpt:
        learning_status = ckpt['learning_status']
    else:
        learning_status = [-1] * N
    # Assumed optimizer steps per train_step call (one pass over zip(X, U)).
    # Used to map iterations to epochs — TODO confirm against the loaders.
    iters_per_epoch = 1024
    # Enough epochs to cover args.iterations (ceil division). This equals the
    # previously hard-coded 1024 epochs when args.iterations == 2**20.
    num_epochs = -(-args.iterations // iters_per_epoch)
    for epoch in range(start_iter // iters_per_epoch, num_epochs):
        Acc, Ls, Lu, Mask = train_step(model=model,
                                       ema=ema,
                                       X=X,
                                       U=U,
                                       N=N,
                                       optimizer=optimizer,
                                       scheduler=scheduler,
                                       num_classes=args.num_classes,
                                       threshold=args.threshold,
                                       unsupervised_weight=args.lu_weight,
                                       amp_flag=args.amp,
                                       scaler=scaler,
                                       learning_status=learning_status,
                                       mapping=mapping,
                                       criterion=criterion,
                                       device=device)
        test_Acc = evaluate_step(ema.shadow, T, device)
        print((f"{time.ctime()}: "
               f"Iteration: [{global_step}/{args.iterations}], "
               f"Ls: {Ls:1.4f}, Lu: {Lu:1.4f}, Mask: {Mask:1.4f}, "
               f"Acc(train/test): [{Acc:1.4f}/{test_Acc:1.4f}]"))
        check_point = {
            'state_dict': model.state_dict(),
            'ema': ema.shadow.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'args': args,
            'learning_status': learning_status,
            'iteration': global_step
        }
        torch.save(check_point, args.save_path / 'ckpt.pth')
        if epoch % 10 == 0:
            torch.save(check_point, args.save_path / f'ckpt_{global_step}.pth')
        if args.wandb:
            wandb.log(data={'Ls': Ls,
                            'Lu': Lu,
                            'Mask': Mask,
                            'Train Acc': Acc,
                            'Test Acc': test_Acc},
                      step=global_step)