-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
223 lines (156 loc) · 8.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
from keras.models import load_model
from keras import backend as K
from models import mlp
import ipynb.fs
import argparse
from keras.utils import to_categorical
from sklearn.model_selection import KFold
from keras.callbacks import ModelCheckpoint
from sklearn.metrics import confusion_matrix
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import csv, os
import pickle as pkl
import numpy as np
import pandas as pd
from ipynb.fs.defs.Make_test_table import create_test_table
# Root directory holding the TADPOLE input CSVs; the __main__ block also writes
# its experiment counter and per-experiment results directories beneath it.
workdir = '/home/users/adoyle/tadpole/data/'
# in_df and ref_df are the first two arguments to create_test_table; both use
# their first CSV column as the index. ref_df must provide an 'RID' column
# (read by the __main__ block).
in_df, ref_df = (
    pd.read_csv(workdir + csv_name, index_col=0)
    for csv_name in ('KNN_50_cleaned.csv', 'KNN_50_cleaned_ref.csv')
)
def plot_graphs(hist, results_dir, fold_num):
    """Save a per-epoch plot of training/validation metrics for one CV fold.

    Plots ventricle-volume MSE and future-diagnosis accuracy (training and
    validation curves) against epoch number and writes the figure to
    ``results_dir + 'training_metrics_fold<fold_num>.png'``.

    Args:
        hist: Keras ``History`` object returned by ``model.fit``; its
            ``history`` dict must contain the four metric keys used below.
        results_dir: output directory path, ending with a path separator
            (the file name is appended by plain string concatenation).
        fold_num: cross-validation fold index used in the file name.
    """
    history = hist.history
    epochs = range(len(history['future_diagnosis_acc']))

    # (history key, legend label) for every curve, drawn in this order.
    curves = (
        ('ventricle_volume_mean_squared_error', 'Ventricle Vol. MSE'),
        ('val_ventricle_volume_mean_squared_error', 'Validation Ventricle Vol. MSE'),
        ('future_diagnosis_acc', 'Future Diagnosis Accuracy'),
        ('val_future_diagnosis_acc', 'Validation Future Diagnosis Accuracy'),
    )

    plt.clf()
    for metric_key, legend_label in curves:
        plt.plot(epochs, history[metric_key], label=legend_label)
    plt.legend(shadow=True)
    plt.xlabel("Training Epoch Number")
    plt.ylabel("Metric Value")

    figure_path = results_dir + 'training_metrics_fold' + str(fold_num) + '.png'
    plt.savefig(figure_path, bbox_inches='tight')
    plt.close()
def test_d2(model, all_rids, results_dir):
    """Print a confusion matrix of next-visit diagnosis predictions.

    Loads train-mode samples (inputs and next-visit targets) for ``all_rids``
    and runs the model on them. It then prints a 3x3 confusion matrix whose
    rows are the true next-visit diagnosis and whose columns are the
    predicted one.

    Args:
        model: Keras model whose ``predict`` returns three arrays. They are
            assumed to be in the target order used with ``model.fit`` in this
            script: (diagnosis probabilities, ADAS-13, ventricle volume).
        all_rids: subject RIDs to evaluate.
        results_dir: currently unused; kept for interface compatibility.
    """
    _, (x_t, y_t, delta_t), (dx_next, _, _), _ = load_data_samples(all_rids, mode='train')

    # A multi-output Keras model returns one array per output (a list), not a
    # single stacked array, so unpack rather than slice rows.
    dx_probabilities, adas_predictions, ventricle_predictions = model.predict([x_t, y_t, delta_t])
    dx_predictions = np.argmax(dx_probabilities, axis=-1)

    # dx_next is one-hot encoded by load_data_samples; reduce it to class indices
    # so both arguments are 1-D label vectors.
    dx_true = np.argmax(dx_next, axis=-1)

    # sklearn expects (y_true, y_pred). Fixing labels keeps the matrix 3x3 even
    # when a class is absent from this subset.
    confusion = confusion_matrix(dx_true, dx_predictions, labels=[0, 1, 2])
    print('Confusion matrix for D2 diagnosis predictions:')
    print(confusion)
    # TODO: un-normalize regression predictions (adas_predictions, ventricle_predictions)
def test_future(model, all_rids, results_dir):
    """Write month-by-month forecasts for each subject to a CSV file.

    For every subject in ``all_rids`` the model is queried once per forecast
    month (``months_forward`` = 1 .. n_months - 1). Each query shifts every
    one of the subject's rows forward by that many months. The outputs are
    averaged over the subject's rows, and one CSV row is written per
    (subject, month) to ``results_dir + 'future_predictions.csv'``.

    Args:
        model: Keras model whose ``predict`` returns three arrays. They are
            assumed to be in the order (diagnosis probabilities with 3 columns,
            ADAS-13, ventricle volume), matching the target order used with
            ``model.fit`` in this script.
        all_rids: subject RIDs to forecast.
        results_dir: output directory path, ending with a path separator
            (the file name is appended by plain string concatenation).
    """
    # In test mode, delta_t is the time until the last timepoint.
    rids, (x_t, y_t, delta_t) = load_data_samples(all_rids, mode='test')
    rid_values = np.asarray(rids)
    n_months = 60  # NOTE(review): months_forward spans 1..59; confirm whether 60 was intended.

    with open(results_dir + 'future_predictions.csv', 'w') as prediction_file:
        prediction_writer = csv.writer(prediction_file, lineterminator='\n')
        prediction_writer.writerow(['ID', 'Months', 'P(Control)', 'P(MCI)', 'P(ALZ)', 'ADAS-13', 'Ventricular Volume'])
        for rid in all_rids:
            # Select this subject's rows via the RID column returned by
            # load_data_samples. Use fresh locals so the full arrays stay intact
            # for the next subject.
            mask = rid_values == rid
            if not mask.any():
                # No complete rows survived dropna() for this subject.
                continue
            x = x_t[mask]
            y = y_t[mask]
            subject_delta = delta_t[mask]
            for months_forward in range(1, n_months):
                dx, adas, vent = model.predict([x, y, subject_delta + months_forward])
                # Averaging over all elements works whether the regression
                # outputs are shaped (n,) or (n, 1).
                dx_mean = np.mean(dx, axis=0)
                prediction_writer.writerow([rid, months_forward, dx_mean[0], dx_mean[1], dx_mean[2], np.mean(adas), np.mean(vent)])
def load_data_samples(rids, mode='train'):
    """Build model inputs (and, in train mode, targets) for the given subjects.

    The table comes from ``create_test_table(in_df, ref_df, rids, mode=mode)``,
    with rows containing any NaN dropped. Columns are then sliced by position:

    * column 0: subject RID
    * columns 1 .. -8: input features
    * columns -7 .. -5: current diagnosis (first), then two more outcome
      columns (presumably current ADAS-13 and ventricle volume; confirm)
    * column -4: time step ``delta_t``
    * columns -3 .. -1: next-visit diagnosis, ADAS-13, ventricle volume
      (used in train mode only)

    Diagnoses are expected to be coded 1..3; they are shifted to 0..2 and
    one-hot encoded.

    NOTE(review): test mode uses the same positional slicing, which assumes
    the test-mode table has the same trailing column layout. Confirm against
    create_test_table.

    Args:
        rids: subject RIDs to include.
        mode: ``'train'`` to also return prediction targets, ``'test'`` for
            inputs only.

    Returns:
        In ``'train'`` mode:
        ``(rids, (x_t, y_t, delta_t), (dx_next, adas_next, ventricle_next), dx_change)``.
        ``y_t`` is the one-hot current diagnosis stacked with the other two
        current outcome columns. ``dx_change`` is an element-wise (n, 3)
        boolean array comparing current and next one-hot diagnoses, so a
        sample whose diagnosis changes has two True entries.
        In ``'test'`` mode: ``(rids, (x_t, y_t, delta_t))``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    if mode not in ('train', 'test'):
        # Reject unknown modes up front instead of returning None, which callers
        # would only notice later as a tuple-unpacking error.
        raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))

    table = create_test_table(in_df, ref_df, rids, mode=mode).dropna()

    sample_rids = table.iloc[:, 0]
    x_t = table.iloc[:, 1:-7]
    y_t = table.iloc[:, -7:-4]
    delta_t = table.iloc[:, -4]

    # The current diagnosis is an input feature.
    dx = to_categorical(y_t.iloc[:, 0] - 1, num_classes=3)
    y_t_categorical = np.hstack((dx, y_t.iloc[:, 1:]))

    if mode == 'test':
        return sample_rids, (x_t, y_t_categorical, delta_t)

    # Prediction targets.
    y_t_next = table.iloc[:, -3:]
    dx_next = to_categorical(y_t_next.iloc[:, 0] - 1, num_classes=3)
    adas_next = y_t_next.iloc[:, 1]
    ventricle_next = y_t_next.iloc[:, 2]

    # Determine which timepoints have a change in diagnosis (element-wise).
    dx_change = np.not_equal(dx, dx_next)
    return sample_rids, (x_t, y_t_categorical, delta_t), (dx_next, adas_next, ventricle_next), dx_change
if __name__ == "__main__":
    # Train a 5-fold cross-validated model, then evaluate/forecast with fold 0's best model.
    print('It\'s not the size that counts, it\'s the connections')

    all_rids = ref_df['RID'].unique()
    print('There are', len(all_rids), 'subjects total in the ADNI dataset')

    # The experiment counter is persisted under workdir so each run writes to a new directory.
    experiment_number_path = workdir + 'experiment_number.pkl'
    try:
        with open(experiment_number_path, 'rb') as counter_file:
            experiment_number = pkl.load(counter_file)
        experiment_number += 1
    except (OSError, EOFError, pkl.UnpicklingError):
        # Missing, empty or unreadable counter file: start counting from zero.
        print('Couldnt find the file to load experiment number')
        experiment_number = 0
    print('This is experiment number:', experiment_number)

    # workdir already ends with '/', so no extra separator is added here.
    results_dir = workdir + 'experiment-' + str(experiment_number) + '/'
    os.makedirs(results_dir)
    with open(experiment_number_path, 'wb') as counter_file:
        pkl.dump(experiment_number, counter_file)

    kf = KFold(n_splits=5, shuffle=True)
    for k, (train_idx, test_idx) in enumerate(kf.split(all_rids)):
        # KFold.split yields positional indices into all_rids, not RID values;
        # map them back so create_test_table receives real subject IDs.
        train_rids = all_rids[train_idx]
        test_rids = all_rids[test_idx]

        (_,
         (x_t_train, y_t_train, delta_t_train),
         (dx_next_train, adas_next_train, ventricle_next_train),
         dx_change_train) = load_data_samples(train_rids, mode='train')
        (_,
         (x_t_test, y_t_test, delta_t_test),
         (dx_next_test, adas_next_test, ventricle_next_test),
         _) = load_data_samples(test_rids, mode='train')

        n_healthy = np.sum(dx_next_train[:, 0])
        n_mci = np.sum(dx_next_train[:, 1])
        n_alzheimers = np.sum(dx_next_train[:, 2])
        n_total = dx_next_train.shape[0]
        print('Distribution of prediction targets this fold (training):')
        print(n_healthy, 'healthy')
        print(n_mci, 'mild cognitive impairment')
        print(n_alzheimers, 'Alzheimer\'s')
        print(n_total, 'total')

        # NOTE(review): these class weights are reported but never passed to model.fit.
        healthy_weight = (n_total - n_healthy) / n_total
        mci_weight = (n_total - n_mci) / n_total
        alzheimers_weight = (n_total - n_alzheimers) / n_total
        print('Class weight (healthy, MCI, Alzheimer\'s):', healthy_weight, mci_weight, alzheimers_weight)

        # dx_change compares one-hot rows element-wise, shape (n, 3), so a changed
        # sample differs in two entries. Reduce per row before counting samples.
        diagnosis_changes = np.sum(np.any(dx_change_train, axis=-1))
        print('Proportion of samples where diagnosis changes between timepoints:', diagnosis_changes / n_total)

        # initialize model and train it
        model = mlp(x_t_train.shape[-1])
        model.summary()
        model_checkpoint = ModelCheckpoint(results_dir + "best_weights_fold_" + str(k) + ".hdf5",
                                           monitor="val_future_diagnosis_acc",
                                           save_best_only=True)
        model.compile(optimizer='adam',
                      loss={'future_diagnosis': 'categorical_crossentropy',
                            'ventricle_volume': 'mean_squared_error',
                            'as_cog': 'mean_squared_error'},
                      loss_weights={'future_diagnosis': 0.00001, 'ventricle_volume': 1, 'as_cog': 0.0001},
                      metrics={'future_diagnosis': 'accuracy',
                               'ventricle_volume': 'mean_squared_error',
                               'as_cog': 'mean_squared_error'}
                      )
        print(model.metrics_names)
        print(model.metrics)
        hist = model.fit([x_t_train, y_t_train, delta_t_train],
                         [dx_next_train, adas_next_train, ventricle_next_train],
                         epochs=200,
                         validation_data=([x_t_test, y_t_test, delta_t_test],
                                          [dx_next_test, adas_next_test, ventricle_next_test]),
                         callbacks=[model_checkpoint])

        # Restore the checkpointed best weights before saving the full model.
        model.load_weights(results_dir + "best_weights_fold_" + str(k) + ".hdf5")
        model.save(results_dir + 'best_tadpole_model' + str(k) + '.hdf5')
        plot_graphs(hist, results_dir, k)
        K.clear_session()

    # NOTE(review): fold 0's model is evaluated on all subjects, including its own training subjects.
    model = load_model(results_dir + 'best_tadpole_model0.hdf5')
    test_d2(model, all_rids, results_dir)
    test_future(model, all_rids, results_dir)