-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_lstm.py
42 lines (37 loc) · 1.51 KB
/
train_lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""Command-line entry point for training a character-level LSTM music model.

Hyperparameters (layer count/size, learning rate, sequence length, epoch
count, validation cadence) come from CLI flags; training itself is delegated
to ``fit_rnn`` from ``train_eval_utils``.
"""
import argparse
from torch import nn
from torch.optim import Adam
from lstm import LSTM
from music_dataloader import create_split_loaders
from torch_utils import setup_device
from train_eval_utils import *  # provides fit_rnn

parser = argparse.ArgumentParser(
    description='Train an LSTM on the music dataset.')
parser.add_argument('-d', '--n-layers', type=int, dest='n_layers', default=1,
                    help='number of stacked LSTM layers')
parser.add_argument('-l', '--layer-size', type=int, dest='layer_size',
                    default=100, help='hidden units per LSTM layer')
parser.add_argument('-lr', '--learning-rate', type=float, dest='lr',
                    default=0.001, help='Adam learning rate')
parser.add_argument('-m', '--model-name', type=str, dest='model_name',
                    help='name passed to fit_rnn (used for saving the model)')
parser.add_argument('-n', '--n-epochs', type=int, dest='n_epochs', default=10,
                    help='number of training epochs')
# FIX: the flag was misspelled '--seq-lenght'. The correct spelling is added
# and the old one kept as an alias so existing command lines still work.
parser.add_argument('-s', '--seq-length', '--seq-lenght', type=int,
                    dest='seq_length', default=100,
                    help='sequence length fed to create_split_loaders')
parser.add_argument('-u', '--update-hist', type=int, dest='update_hist',
                    default=25, help='history update interval for fit_rnn')
parser.add_argument('-v', '--val_every', type=int, dest='val_every',
                    default=1000, help='run validation every N steps')
args = parser.parse_args()

# Compute device chosen by the project helper (presumably GPU when
# available — see torch_utils.setup_device).
computing_device = setup_device()

# The loaders and the character dictionary come from one call; the
# dictionary size is used below for both input and output widths.
train_loader, val_loader, test_loader, dictionary = create_split_loaders(args.seq_length)

lstm = LSTM(len(dictionary),   # input width  = vocabulary size
            args.layer_size,   # hidden units per layer
            len(dictionary),   # output width = vocabulary size
            computing_device,
            n_layers=args.n_layers)
lstm.to(computing_device)

# Next-symbol prediction is treated as classification over the vocabulary,
# hence cross-entropy on raw model outputs.
criterion = nn.CrossEntropyLoss()
optimizer = Adam(lstm.parameters(), lr=args.lr)

fit_rnn(lstm,
        criterion,
        optimizer,
        train_loader,
        val_loader,
        args.n_epochs,
        args.model_name,
        computing_device,
        val_every=args.val_every,
        update_hist=args.update_hist)