-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_eval_utils.py
83 lines (69 loc) · 3.12 KB
/
train_eval_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from datetime import datetime as dt
import itertools
import numpy as np
import os
import pickle
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from lstm import LSTM
from rnn import RNN
from music_dataloader import create_split_loaders
from torch_utils import setup_device
def fit_rnn(model, criterion, optimizer, train_loader, val_loader, n_epochs, model_name, computing_device,
            seq_length=100, chkpt_every=25, update_hist=50, val_every=35000):
    """Train ``model`` on ``train_loader``, validating and checkpointing periodically.

    Each item of ``train_loader`` is an ``(x, y)`` pair. ``x`` gets a leading
    dimension of size 1 (``torch.unsqueeze(x, 0)``) before the forward pass. The
    model output is squeezed and scored with ``criterion(output, y)``. (Assumes
    the loader yields one un-batched sequence per item — TODO confirm against the
    loader.)

    Every ``val_every`` minibatches, counted across epochs, the model is evaluated
    on ``val_loader`` via ``evaluate_model``. Both loss histories are then pickled
    to ``train-stats/<model_name>_train.pkl`` and ``train-stats/<model_name>_val.pkl``.
    The model's ``state_dict`` is saved to ``models/<model_name>.pt`` whenever the
    validation loss reaches a new minimum. If fewer than ``val_every``
    minibatches are seen in total, nothing is written to disk.

    Args:
        model: the ``nn.Module`` to train.
        criterion: loss callable ``criterion(output, target)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(x, y)`` tensor pairs.
        val_loader: loader passed through to ``evaluate_model``.
        n_epochs: number of passes over ``train_loader``.
        model_name: base file name for the saved model and loss pickles.
        computing_device: device passed to ``Tensor.to``.
        seq_length: unused; kept for backward compatibility.
        chkpt_every: unused; kept for backward compatibility.
        update_hist: number of most recent minibatch losses averaged in the
            progress line.
        val_every: validate and checkpoint every this many minibatches.

    Side effects:
        Creates ``models/`` and ``train-stats/`` in the current working
        directory and prints progress to stdout.
    """
    train_losses = {}  # epoch index -> list of per-minibatch training losses
    val_losses = {}  # total minibatches seen -> mean validation loss
    total_seen = 0
    start_time = dt.now()
    min_val_loss = torch.tensor(np.inf)

    # Make the directories for the model and the losses. exist_ok avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs('models', exist_ok=True)
    os.makedirs('train-stats', exist_ok=True)
    model_save_path = os.path.join('models', model_name + '.pt')
    train_save_path = os.path.join('train-stats', model_name + '_train.pkl')
    val_save_path = os.path.join('train-stats', model_name + '_val.pkl')
    update_str = '[TIME ELAPSED]: {0} [EPOCH {1}]: Avg. loss for last {2} minibatches: {3:0.5f}'

    for epoch in range(n_epochs):
        epoch_losses = train_losses[epoch] = []
        for x, y in train_loader:
            # evaluate_model puts the model in eval mode, so re-enter train
            # mode at every step.
            model.train()
            optimizer.zero_grad()
            x, y = x.to(computing_device), y.to(computing_device)
            g = torch.squeeze(model(torch.unsqueeze(x, 0)))
            loss = criterion(g, y)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.detach().cpu().numpy())
            total_seen += 1

            # Report training stats: the mean is over the last `update_hist`
            # minibatches, so that is the window size shown in the message.
            avg_loss = np.mean(epoch_losses[-update_hist:])
            time_delta = dt.now() - start_time
            print(update_str.format(str(time_delta), epoch + 1, update_hist, avg_loss), end='\r')

            # Validate, persist the loss histories, and keep the best model.
            if total_seen % val_every == 0:
                val_loss = evaluate_model(model, val_loader, criterion, computing_device,
                                          start_time)
                val_losses[total_seen] = val_loss
                with open(train_save_path, 'wb') as f:
                    pickle.dump(train_losses, f)
                with open(val_save_path, 'wb') as f:
                    pickle.dump(val_losses, f)
                if val_loss < min_val_loss:
                    torch.save(model.state_dict(), model_save_path)
                    min_val_loss = val_loss

        # Guard against an empty loader: np.mean([]) warns and yields NaN.
        epoch_mean = np.mean(epoch_losses) if epoch_losses else float('nan')
        print(f'\n[EPOCH {epoch + 1}]: Avg. loss for epoch: {epoch_mean}\n')
def evaluate_model(model, loader, criterion, computing_device, start_time):
    """Return the mean ``criterion`` loss of ``model`` over every batch in ``loader``.

    Each ``(x, y)`` item is moved to ``computing_device``. ``x`` gets a leading
    dimension of size 1 before the forward pass, and the squeezed output is
    scored with ``criterion(output, y)``. Forward passes run under
    ``torch.no_grad()``, so no autograd graph is built for validation. Per-batch
    losses are stored as Python floats, so no graph-carrying tensors are kept
    alive.

    The model is switched to eval mode and left in it; callers that continue
    training must call ``model.train()`` themselves.

    Args:
        model: the ``nn.Module`` to evaluate.
        loader: sized iterable (``len`` is used for progress) of ``(x, y)`` pairs.
        criterion: loss callable; assumed to return a single-element tensor
            (``.item()`` is called on it).
        computing_device: device passed to ``Tensor.to``.
        start_time: unused; kept for backward compatibility.

    Returns:
        A 0-d tensor holding the mean loss, or NaN if ``loader`` is empty.
    """
    model.eval()
    batch_losses = []
    n_batches = len(loader)
    # Blank out any '\r'-terminated progress line (e.g. fit_rnn's) before
    # printing validation progress over it.
    print(' ' * 120, end='\r')
    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            x = x.to(computing_device)
            y = y.to(computing_device)
            g = torch.squeeze(model(torch.unsqueeze(x, 0)))
            batch_losses.append(criterion(g, y).item())
            print(f'Validating model: {i + 1}/{n_batches}', end='\r')
    return torch.mean(torch.tensor(batch_losses))