forked from willwhitney/dc-ign
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtestf.lua
195 lines (156 loc) · 5.44 KB
/
testf.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
--- Evaluate the model on the standard test set, batch by batch.
-- For every batch t in 1..num_test_batches this:
-- * loads the batch with load_batch(t, MODE_TEST);
-- * runs it through `model`;
-- * adds the negated `criterion` output plus the `KLD` output to a running
--   lower bound;
-- * accumulates the summed squared reconstruction error.
-- Afterwards it prints the average elapsed sys.clock() time per test sample
-- and the mean squared error per output element, and logs that MSE to
-- `testLogger`.
-- Relies on globals defined elsewhere: num_test_batches, load_batch,
-- MODE_TEST, model, criterion, KLD, testLogger (plus torch, sys, xlua).
-- @param saveAll if truthy, save every batch's predictions to tmp/preds<t>;
--                otherwise only the predictions of batch 1 are saved.
-- @return the lower bound summed over all test batches.
function testf(saveAll)
  -- '-p' keeps mkdir quiet when tmp/ already exists
  os.execute('mkdir -p ' .. 'tmp')
  local time = sys.clock()
  print('<trainer> on testing Set:')
  -- `reconstruction` and `inputs` stay global, as before, in case other
  -- parts of the program read them.
  reconstruction = 0
  local lowerbound = 0
  local num_samples = 0   -- rows processed, for the per-sample timing
  local num_elements = 0  -- output values processed, for the per-element MSE
  for t = 1, num_test_batches do
    collectgarbage()
    -- the batch serves both as model input and as reconstruction target
    local targets = load_batch(t, MODE_TEST)
    inputs = targets:cuda()
    xlua.progress(t, num_test_batches)
    local preds = model:forward(inputs)
    local err = - criterion:forward(preds, targets:cuda())
    -- output of the model's first sub-module feeds the KLD term
    local encoder_output = model:get(1).output
    local KLDerr = KLD:forward(encoder_output, targets)
    lowerbound = lowerbound + err + KLDerr
    preds = preds:float()
    reconstruction = reconstruction + torch.sum(torch.pow(preds - targets, 2))
    -- assumes dim 1 of preds is the batch dimension — TODO confirm
    num_samples = num_samples + preds:size(1)
    num_elements = num_elements + preds:nElement()
    if saveAll or t == 1 then
      torch.save('tmp/preds' .. t, preds)
    end
  end
  -- timing: divide by samples (not batches) so the value matches the message
  time = sys.clock() - time
  time = time / math.max(num_samples, 1)
  print("<trainer> time to test 1 sample = " .. (time*1000) .. 'ms')
  -- mean squared error per output element, derived from the tensors actually
  -- seen rather than a hard-coded image geometry
  reconstruction = reconstruction / math.max(num_elements, 1)
  print('mean MSE error (test set)', reconstruction)
  -- historical column name: the logged value is the MSE, not an accuracy
  testLogger:add{['% mean class accuracy (test set)'] = reconstruction}
  reconstruction = 0
  return lowerbound
end
--- Evaluate the model on the CIFAR test set held in the global `testData`.
-- Samples are read from testData.data in chunks of opt.batchSize, flattened
-- into rows of 3*32*32 values and scaled by 1/255. The final chunk is
-- zero-padded up to opt.batchSize rows.
-- For each chunk this:
-- * runs the chunk through `model`;
-- * adds the negated `criterion` output plus the `KLD` output to a running
--   lower bound;
-- * accumulates the squared reconstruction error over the real (non-padded)
--   rows only.
-- Afterwards it prints the average elapsed sys.clock() time per test sample
-- and the per-element MSE, and logs that MSE to `testLogger`.
-- Relies on globals defined elsewhere: testData, model, criterion, KLD,
-- testLogger (plus torch, sys, xlua).
-- @param saveAll if truthy, save every chunk's predictions to tmp/preds<t>
--                (t = index of the chunk's first sample); otherwise only the
--                first chunk is saved.
-- @param opt     options table; opt.batchSize is the chunk size.
-- @return the lower bound summed over all chunks.
function testf_cifar(saveAll, opt)
  -- '-p' keeps mkdir quiet when tmp/ already exists
  os.execute('mkdir -p ' .. 'tmp')
  local time = sys.clock()
  print('<trainer> on testing Set:')
  -- `reconstruction` and `inputs` stay global, as before, in case other
  -- parts of the program read them.
  reconstruction = 0
  local lowerbound = 0
  local num_test = testData.data:size(1)
  local num_elements = 0  -- output values in non-padded rows, for the MSE
  for t = 1, num_test, opt.batchSize do
    collectgarbage()
    -- fixed-size chunk; rows past `last` remain zero padding
    local last = math.min(t + opt.batchSize - 1, num_test)
    local num_valid = last - t + 1
    local raw_inputs = torch.zeros(opt.batchSize, 3 * 32 * 32)
    for ii = t, last do
      raw_inputs[ii - t + 1] = testData.data[ii]
    end
    raw_inputs = raw_inputs / 255.0
    local targets = raw_inputs
    inputs = raw_inputs:cuda()
    xlua.progress(last, num_test)
    local preds = model:forward(inputs)
    local err = - criterion:forward(preds, targets:cuda())
    -- output of the model's first sub-module feeds the KLD term
    local encoder_output = model:get(1).output
    local KLDerr = KLD:forward(encoder_output, targets)
    -- NOTE(review): padded rows still contribute to err/KLDerr; excluding
    -- them would need a per-sample criterion.
    lowerbound = lowerbound + err + KLDerr
    preds = preds:float()
    targets = targets:float()
    -- only the real rows count toward the reconstruction error
    -- (assumes dim 1 of preds is the batch dimension — TODO confirm)
    local valid_preds = preds:narrow(1, 1, num_valid)
    local valid_targets = targets:narrow(1, 1, num_valid)
    reconstruction = reconstruction
      + torch.sum(torch.pow(valid_preds - valid_targets, 2))
    num_elements = num_elements + valid_preds:nElement()
    if saveAll or t == 1 then
      torch.save('tmp/preds' .. t, preds)
    end
  end
  -- timing: average over the CIFAR test samples actually evaluated
  time = sys.clock() - time
  time = time / math.max(num_test, 1)
  print("<trainer> time to test 1 sample = " .. (time*1000) .. 'ms')
  -- mean squared error per output element (CIFAR geometry, not 150x150)
  reconstruction = reconstruction / math.max(num_elements, 1)
  print('mean MSE error (test set)', reconstruction)
  -- historical column name: the logged value is the MSE, not an accuracy
  testLogger:add{['% mean class accuracy (test set)'] = reconstruction}
  reconstruction = 0
  return lowerbound
end
--- Evaluate the model on the monovariant test sets.
-- For each dataset in AZ_VARIED, EL_VARIED, LIGHT_AZ_VARIED and SHAPE_VARIED
-- (in that order) and each t in 1..opt.num_test_batches_per_type, this:
-- * loads a batch with load_mv_batch(t, dataset_name, MODE_TEST);
-- * runs it through `model`;
-- * adds the negated `criterion` output plus the `KLD` output to a running
--   lower bound;
-- * accumulates the summed squared reconstruction error.
-- The `active` flags of all clamps and grad filters are switched off for the
-- duration of the test and restored to their previous values afterwards.
-- Predictions go to tmp/<opt.save>/<dataset_name>/epoch_<epoch>/preds<t>.
-- Afterwards it prints the average elapsed sys.clock() time per test sample
-- and the per-element MSE, and logs that MSE to `testLogger`.
-- Relies on globals defined elsewhere: opt, epoch, clamps, gradFilters,
-- load_mv_batch, MODE_TEST, model, criterion, KLD, testLogger.
-- @param saveAll if truthy, save every batch's predictions; otherwise only
--                batches with t < 10 are saved.
-- @return the lower bound summed over all evaluated batches.
function testf_MV(saveAll)
  -- the space after -p matters: without it the command is `mkdir -ptmp`
  os.execute('mkdir -p ' .. 'tmp')
  local time = sys.clock()
  print('<trainer> on testing Set:')
  -- `reconstruction` and `inputs` stay global, as before, in case other
  -- parts of the program read them.
  reconstruction = 0
  local lowerbound = 0
  local num_samples = 0   -- rows processed, for the per-sample timing
  local num_elements = 0  -- output values processed, for the per-element MSE
  -- turn the clamps off so we get full batch outputs, remembering their
  -- previous state so it can be restored after testing
  local saved_clamp_active, saved_filter_active = {}, {}
  for clampIndex = 1, #clamps do
    saved_clamp_active[clampIndex] = clamps[clampIndex].active
    saved_filter_active[clampIndex] = gradFilters[clampIndex].active
    clamps[clampIndex].active = false
    gradFilters[clampIndex].active = false
  end
  -- ipairs guarantees the datasets are visited in the listed order
  for _, dataset_name in ipairs({"AZ_VARIED", "EL_VARIED", "LIGHT_AZ_VARIED", "SHAPE_VARIED"}) do
    local save_dir = 'tmp' .. '/' .. opt.save .. '/' .. dataset_name .. '/epoch_' .. epoch
    os.execute('mkdir -p ' .. save_dir)
    for t = 1, opt.num_test_batches_per_type do
      collectgarbage()
      -- the batch serves both as model input and as reconstruction target
      local targets = load_mv_batch(t, dataset_name, MODE_TEST)
      inputs = targets:cuda()
      xlua.progress(t, opt.num_test_batches_per_type)
      local preds = model:forward(inputs)
      local err = - criterion:forward(preds, targets:cuda())
      -- output of the model's first sub-module feeds the KLD term
      local encoder_output = model:get(1).output
      local KLDerr = KLD:forward(encoder_output, targets)
      lowerbound = lowerbound + err + KLDerr
      preds = preds:float()
      reconstruction = reconstruction + torch.sum(torch.pow(preds - targets, 2))
      -- assumes dim 1 of preds is the batch dimension — TODO confirm
      num_samples = num_samples + preds:size(1)
      num_elements = num_elements + preds:nElement()
      if saveAll or t < 10 then
        torch.save(save_dir..'/preds' .. t, preds)
      end
    end
  end
  -- put the clamps / grad filters back the way the caller had them
  for clampIndex = 1, #clamps do
    clamps[clampIndex].active = saved_clamp_active[clampIndex]
    gradFilters[clampIndex].active = saved_filter_active[clampIndex]
  end
  -- timing: divide by the samples actually evaluated across all four datasets
  time = sys.clock() - time
  time = time / math.max(num_samples, 1)
  print("<trainer> time to test 1 sample = " .. (time*1000) .. 'ms')
  -- mean squared error per output element, derived from the tensors seen
  reconstruction = reconstruction / math.max(num_elements, 1)
  print('mean MSE error (test set)', reconstruction)
  -- historical column name: the logged value is the MSE, not an accuracy
  testLogger:add{['% mean class accuracy (test set)'] = reconstruction}
  reconstruction = 0
  return lowerbound
end