forked from willwhitney/dc-ign
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path utils.lua
76 lines (66 loc) · 2.24 KB
/
utils.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
--- Load a single serialized batch from the fixed DATASET directory.
-- @tparam number|string id batch index appended to the filename
-- @tparam string mode split name used as the directory suffix (e.g. "training")
-- @return the deserialized batch object, as written by torch.save
function load_batch(id, mode)
  -- 'res' was an accidental global; keep it local to this function
  local res = torch.load('DATASET/th_' .. mode .. '/batch' .. id)
  return res
end
--- Load a multi-view batch for a given variation type and split.
-- Reads from opt.datasetdir (global options table) — TODO confirm opt is set
-- before any loader is called.
-- @tparam number|string id batch index appended to the filename
-- @tparam string dataset_name variation directory, e.g. "AZ_VARIED"
-- @tparam string mode split subdirectory (training/test)
-- @return the deserialized batch object
function load_mv_batch(id, dataset_name, mode)
  -- 'dataset' was an accidental global; keep the relative path local
  local dataset = '/th_' .. dataset_name .. '/' .. mode .. '/batch' .. id
  print("loading: " .. dataset)
  return torch.load(opt.datasetdir .. dataset)
end
--- Load a random batch, drawing the variation type uniformly from the four types.
-- @tparam string mode MODE_TRAINING or MODE_TEST; any other value falls back
--   to batch id 1 (preserved from the original behavior)
-- @return batch (as returned by load_mv_batch), variation_type in 1..4
function load_random_mv_batch(mode)
  -- index -> dataset directory name; replaces the original if/elseif chain
  local variation_names = {"AZ_VARIED", "EL_VARIED", "LIGHT_AZ_VARIED", "SHAPE_VARIED"}
  local variation_type = math.random(4)
  local variation_name = variation_names[variation_type]
  -- 'id' was an accidental global; default 1 when mode is unrecognized
  local id = 1
  if mode == MODE_TRAINING then
    id = math.random(opt.num_train_batches_per_type)
  elseif mode == MODE_TEST then
    id = math.random(opt.num_test_batches_per_type)
  end
  return load_mv_batch(id, variation_name, mode), variation_type
end
--- Load a random batch with a configurable bias towards shape samples.
-- Draws from 4 + opt.shape_bias_amount equally likely outcomes; every draw
-- above 3 is collapsed onto SHAPE_VARIED, so shape batches are chosen with
-- probability (1 + shape_bias_amount) / (4 + shape_bias_amount).
-- @tparam string mode MODE_TRAINING or MODE_TEST; any other value falls back
--   to batch id 1 (preserved from the original behavior)
-- @return batch (as returned by load_mv_batch), variation_type in 1..4
function load_random_mv_shape_bias_batch(mode)
  local variation_type = math.random(4 + opt.shape_bias_amount)
  local variation_name
  if variation_type == 1 then
    variation_name = "AZ_VARIED"
  elseif variation_type == 2 then
    variation_name = "EL_VARIED"
  elseif variation_type == 3 then
    variation_name = "LIGHT_AZ_VARIED"
  else
    -- all biased draws (4 and above) map onto the shape type
    variation_name = "SHAPE_VARIED"
    variation_type = 4
  end
  -- 'id' was an accidental global; default 1 when mode is unrecognized
  local id = 1
  if mode == MODE_TRAINING then
    id = math.random(opt.num_train_batches_per_type)
  elseif mode == MODE_TEST then
    id = math.random(opt.num_test_batches_per_type)
  end
  return load_mv_batch(id, variation_name, mode), variation_type
end
--- Accumulate the variational lower bound of the model over `data`.
-- Runs the model forward batch by batch and sums negative reconstruction
-- error plus the KL divergence term from the encoder output.
-- Relies on globals: num_test_batches, batchSize, model, criterion, KLD —
-- NOTE(review): assumes #data >= num_test_batches and that num_test_batches
-- is a multiple-friendly count for batchSize; verify against the caller.
-- @tparam Tensor data stacked examples, sliceable as data[{{i,j},{}}]
-- @treturn number the accumulated lower bound
function getLowerbound(data)
  local lowerbound = 0
  -- 'N_data' was an accidental global; keep it local
  local N_data = num_test_batches
  -- reconstruction-target buffer, hoisted so it is actually reused: the
  -- original 'local target = target or batch.new()' inside the loop shadowed
  -- itself and allocated a fresh buffer on every iteration
  local target
  for i = 1, N_data, batchSize do
    local batch = data[{{i, i + batchSize - 1}, {}}]
    local f = model:forward(batch)
    target = target or batch.new()
    target:resizeAs(f):copy(batch)
    -- criterion returns a loss; negated per the lower-bound convention here
    local err = - criterion:forward(f, target)
    local encoder_output = model:get(1).output
    local KLDerr = KLD:forward(encoder_output, target)
    lowerbound = lowerbound + err + KLDerr
  end
  return lowerbound
end