Skip to content

Commit 6d8858d

Browse files
author
Nicholas Leonard
committed
move dpnn modules
1 parent 31bd914 commit 6d8858d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+155
-4596
lines changed

AbstractRecurrent.lua

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ function AbstractRecurrent:getStepModule(step)
3333
end
3434

3535
function AbstractRecurrent:updateOutput(input)
36+
if self.train ~= false then
37+
self:recycle()
38+
end
3639
if self.zeroMask then
3740
-- where zeroMask = 1, the past is forgotten, that is, the output/gradOutput is zeroed
3841
local stepmodule = (self.train==false) and self.modules[1] or self:getStepModule(self.step)
@@ -189,6 +192,7 @@ end
189192

190193
function AbstractRecurrent:maskZero(v1)
191194
if not self.maskzero then
195+
assert(not torch.isTypeOf(self.modules[1], 'nn.AbstractRecurrent'), "Doesn't support zero-masking on nested AbstractRecurrent instances")
192196
self.maskzero = true
193197
local stepmodule = nn.MaskZero(self.modules[1], v1)
194198
self.sharedClones = {stepmodule}

AbstractSequencerCriterion.lua

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ function AbstractSequencerCriterion:setZeroMask(zeroMask)
4646
end
4747
end
4848

49+
function AbstractSequencerCriterion:type(type, typecache)
50+
for key, clone in pairs(self.clones) do
51+
clone:type(type, typecache)
52+
end
53+
return parent.type(self, type, typecache)
54+
end

BatchNormalization.lua

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
local _ = require 'moses'
22
local BN, parent = nn.BatchNormalization, nn.Module
33

4-
local empty = _.clone(parent.dpnn_mediumEmpty)
5-
table.insert(empty, 'buffer')
6-
table.insert(empty, 'buffer2')
7-
table.insert(empty, 'centered')
8-
table.insert(empty, 'std')
9-
table.insert(empty, 'normalized')
10-
table.insert(empty, 'output')
11-
table.insert(empty, 'gradInput')
12-
BN.dpnn_mediumEmpty = empty
13-
144
-- for sharedClone
155
local params = _.clone(parent.dpnn_parameters)
166
table.insert(params, 'running_mean')

BinaryLogisticRegression.lua

Lines changed: 0 additions & 91 deletions
This file was deleted.

CAddTensorTable.lua

Lines changed: 0 additions & 43 deletions
This file was deleted.

CMakeLists.txt

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,15 @@ SET(luasrc
1717
AbstractSequencer.lua
1818
AbstractSequencerCriterion.lua
1919
BiSequencer.lua
20-
deprecated/BiSequencerLM.lua
2120
CopyGrad.lua
2221
Dropout.lua
2322
ExpandAs.lua
24-
deprecated/FastLSTM.lua
25-
deprecated/GRU.lua
26-
LinearNoBias.lua
2723
LookupTableMaskZero.lua
28-
deprecated/LSTM.lua
2924
MaskZero.lua
3025
MaskZeroCriterion.lua
3126
Module.lua
3227
Mufuru.lua
3328
NormStabilizer.lua
34-
Padding.lua
3529
Recurrence.lua
3630
RecurrentAttention.lua
3731
Recursor.lua
@@ -42,11 +36,8 @@ SET(luasrc
4236
SeqBLSTM.lua
4337
SeqGRU.lua
4438
SeqLSTM.lua
45-
deprecated/SeqLSTMP.lua
46-
deprecated/SeqReverseSequence.lua
4739
Sequencer.lua
4840
SequencerCriterion.lua
49-
ZeroGrad.lua
5041
test/bigtest.lua
5142
test/test.lua
5243
VariableLength.lua
@@ -60,52 +51,30 @@ SET(luasrc
6051
ArgMax.lua
6152
BatchNormalization.lua
6253
BinaryClassReward.lua
63-
BinaryLogisticRegression.lua
64-
CAddTensorTable.lua
6554
CategoricalEntropy.lua
66-
Clip.lua
67-
Collapse.lua
68-
Constant.lua
6955
Container.lua
70-
Convert.lua
7156
Criterion.lua
72-
Dictionary.lua
73-
FireModule.lua
74-
Inception.lua
75-
Kmeans.lua
7657
LookupTable.lua
77-
ModuleCriterion.lua
7858
NCECriterion.lua
7959
NCEModule.lua
80-
OneHot.lua
81-
PCAColorTransform.lua
8260
ParallelTable.lua
83-
PrintSize.lua
8461
Reinforce.lua
8562
ReinforceBernoulli.lua
8663
ReinforceCategorical.lua
8764
ReinforceGamma.lua
8865
ReinforceNormal.lua
8966
ReverseSequence.lua
9067
Sequential.lua
91-
Serial.lua
92-
SimpleColorTransform.lua
93-
SpatialBatchNormalization.lua
94-
SpatialBinaryConvolution.lua
95-
SpatialBinaryLogisticRegression.lua
96-
SpatialConvolution.lua
97-
SpatialConvolutionMM.lua
98-
SpatialFeatNormalization.lua
9968
SpatialGlimpse.lua
100-
SpatialMaxPooling.lua
101-
SpatialRegionDropout.lua
102-
SpatialUniformCrop.lua
10369
TotalDropout.lua
10470
VRClassReward.lua
105-
WhiteNoise.lua
106-
ZipTable.lua
107-
ZipTableOneToMany.lua
10871
ReverseUnreverse.lua
72+
deprecated/SeqLSTMP.lua
73+
deprecated/SeqReverseSequence.lua
74+
deprecated/BiSequencerLM.lua
75+
deprecated/FastLSTM.lua
76+
deprecated/GRU.lua
77+
deprecated/LSTM.lua
10978
)
11079

11180
ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")

Clip.lua

Lines changed: 0 additions & 35 deletions
This file was deleted.

Collapse.lua

Lines changed: 0 additions & 26 deletions
This file was deleted.

Constant.lua

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)