Skip to content

Commit b99fd45

Browse files
authored
Merge pull request #484 from cogmission/potential_radius_fix
Added fix for potential radius
2 parents 820eb63 + 831607c commit b99fd45

File tree

7 files changed

+85
-12
lines changed

7 files changed

+85
-12
lines changed

src/main/java/org/numenta/nupic/Connections.java

+28-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@
4848
import org.numenta.nupic.model.ProximalDendrite;
4949
import org.numenta.nupic.model.Segment;
5050
import org.numenta.nupic.model.Synapse;
51+
import org.numenta.nupic.network.Persistence;
52+
import org.numenta.nupic.network.PersistenceAPI;
53+
import org.numenta.nupic.serialize.SerialConfig;
5154
import org.numenta.nupic.util.AbstractSparseBinaryMatrix;
55+
import org.numenta.nupic.util.ArrayUtils;
5256
import org.numenta.nupic.util.FlatMatrix;
5357
import org.numenta.nupic.util.SparseMatrix;
5458
import org.numenta.nupic.util.SparseObjectMatrix;
@@ -69,6 +73,10 @@ public class Connections implements Persistable {
6973
private static final double EPSILON = 0.00001;
7074

7175
/////////////////////////////////////// Spatial Pooler Vars ///////////////////////////////////////////
76+
/** <b>WARNING:</b> potentialRadius **must** be set to
77+
* the inputWidth if using "globalInhibition" and if not
78+
* using the Network API (which sets this automatically)
79+
*/
7280
private int potentialRadius = 16;
7381
private double potentialPct = 0.5;
7482
private boolean globalInhibition = false;
@@ -242,12 +250,25 @@ public class Connections implements Persistable {
242250
*/
243251
public Connections() {}
244252

253+
/**
254+
* Returns a deep copy of this {@code Connections} object.
255+
* @return a deep copy of this {@code Connections}
256+
*/
257+
public Connections copy() {
258+
PersistenceAPI api = Persistence.get(new SerialConfig());
259+
byte[] myBytes = api.serializer().serialize(this);
260+
return api.serializer().deSerialize(myBytes);
261+
}
262+
245263
/**
246264
* Sets the derived values of the {@link SpatialPooler}'s initialization.
247265
*/
248266
public void doSpatialPoolerPostInit() {
249267
synPermBelowStimulusInc = synPermConnected / 10.0;
250268
synPermTrimThreshold = synPermActiveInc / 2.0;
269+
if(potentialRadius == -1) {
270+
potentialRadius = ArrayUtils.product(inputDimensions);
271+
}
251272
}
252273

253274
/////////////////////////////////////////
@@ -480,6 +501,11 @@ public void setNumColumns(int n) {
480501
* parameter defines a square (or hyper square) area: a
481502
* column will have a max square potential pool with
482503
* sides of length 2 * potentialRadius + 1.
504+
*
505+
* <b>WARNING:</b> potentialRadius **must** be set to
506+
* the inputWidth if using "globalInhibition" and if not
507+
* using the Network API (which sets this automatically)
508+
*
483509
*
484510
* @param potentialRadius
485511
*/
@@ -489,11 +515,12 @@ public void setPotentialRadius(int potentialRadius) {
489515

490516
/**
491517
* Returns the configured potential radius
518+
*
492519
* @return the configured potential radius
493520
* @see setPotentialRadius
494521
*/
495522
public int getPotentialRadius() {
496-
return Math.min(numInputs, potentialRadius);
523+
return potentialRadius;
497524
}
498525

499526
/**

src/main/java/org/numenta/nupic/Parameters.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ public class Parameters implements Persistable {
9797
//////////// Spatial Pooler Parameters ///////////
9898
Map<KEY, Object> defaultSpatialParams = new ParametersMap();
9999
defaultSpatialParams.put(KEY.INPUT_DIMENSIONS, new int[]{64});
100-
defaultSpatialParams.put(KEY.POTENTIAL_RADIUS, 16);
100+
defaultSpatialParams.put(KEY.POTENTIAL_RADIUS, -1);
101101
defaultSpatialParams.put(KEY.POTENTIAL_PCT, 0.5);
102102
defaultSpatialParams.put(KEY.GLOBAL_INHIBITION, false);
103103
defaultSpatialParams.put(KEY.INHIBITION_RADIUS, 0);
@@ -225,6 +225,10 @@ public static enum KEY {
225225

226226
/////////// Spatial Pooler Parameters ///////////
227227
INPUT_DIMENSIONS("inputDimensions", int[].class),
228+
/** <b>WARNING:</b> potentialRadius **must** be set to
229+
* the inputWidth if using "globalInhibition" and if not
230+
* using the Network API (which sets this automatically)
231+
*/
228232
POTENTIAL_RADIUS("potentialRadius", Integer.class),
229233
POTENTIAL_PCT("potentialPct", Double.class), //TODO add range here?
230234
GLOBAL_INHIBITION("globalInhibition", Boolean.class),
@@ -770,6 +774,11 @@ public void setInputDimensions(int[] inputDimensions) {
770774
* parameter defines a square (or hyper square) area: a
771775
* column will have a max square potential pool with
772776
* sides of length 2 * potentialRadius + 1.
777+
*
778+
* <b>WARNING:</b> potentialRadius **must** be set to
779+
* the inputWidth if using "globalInhibition" and if not
780+
* using the Network API (which sets this automatically)
781+
*
773782
*
774783
* @param potentialRadius
775784
*/

src/main/java/org/numenta/nupic/network/Layer.java

+2-4
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,7 @@ public Layer<T> close() {
495495
params.setInputDimensions(upstreamDims);
496496
connections.setInputDimensions(upstreamDims);
497497
} else if(parentRegion != null && parentNetwork != null
498-
&& parentRegion.equals(parentNetwork.getSensorRegion()) && encoder == null
499-
&& spatialPooler != null) {
500-
498+
&& parentRegion.equals(parentNetwork.getSensorRegion()) && encoder == null && spatialPooler != null) {
501499
Layer<?> curr = this;
502500
while((curr = curr.getPrevious()) != null) {
503501
if(curr.getEncoder() != null) {
@@ -692,7 +690,7 @@ public Subscription subscribe(final Observer<Inference> subscriber) {
692690

693691
return createSubscription(subscriber);
694692
}
695-
693+
696694
/**
697695
* Allows the user to define the {@link Connections} object data structure
698696
* to use. Or possibly to share connections between two {@code Layer}s

src/test/java/org/numenta/nupic/ConnectionsTest.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,23 @@
2727
import org.numenta.nupic.util.ArrayUtils;
2828
import org.numenta.nupic.util.MersenneTwister;
2929

30+
import com.cedarsoftware.util.DeepEquals;
31+
3032

3133
public class ConnectionsTest {
34+
@Test
35+
public void testCopy() {
36+
Parameters retVal = Parameters.getTemporalDefaultParameters();
37+
retVal.set(KEY.COLUMN_DIMENSIONS, new int[] { 32 });
38+
retVal.set(KEY.CELLS_PER_COLUMN, 4);
39+
40+
Connections connections = new Connections();
41+
42+
retVal.apply(connections);
43+
TemporalMemory.init(connections);
44+
45+
assertTrue(DeepEquals.deepEquals(connections, connections.copy()));
46+
}
3247

3348
@Test
3449
public void testCreateSegment() {
@@ -574,7 +589,7 @@ public void testGetPrintString() {
574589
TemporalMemory.init(con);
575590

576591
String output = con.getPrintString();
577-
assertEquals(1370, output.length());
592+
assertEquals(1371, output.length());
578593

579594
Set<String> fieldSet = Parameters.getEncoderDefaultParameters().keys().stream().
580595
map(k -> k.getFieldName()).collect(Collectors.toCollection(LinkedHashSet::new));

src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ private void initSP() {
9898
parameters.apply(mem);
9999
sp.init(mem);
100100
}
101-
101+
102102
@Test
103103
public void confirmSPConstruction() {
104104
setupParameters();

src/test/java/org/numenta/nupic/network/NetworkTest.java

+27-3
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,7 @@ public void testBasicNetworkRunAWhileThenHalt() {
548548
@Test
549549
public void testRegionHierarchies() {
550550
Parameters p = NetworkTestHarness.getParameters();
551+
p.setPotentialRadius(16);
551552
p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
552553
p.set(KEY.RANDOM, new MersenneTwister(42));
553554

@@ -1007,6 +1008,29 @@ public void testObservableWithCoordinateEncoder_NEGATIVE() {
10071008
assertTrue(hasErrors(tester));
10081009
}
10091010

1011+
@Test
1012+
public void testPotentialRadiusFollowsInputWidth() {
1013+
Parameters p = NetworkTestHarness.getParameters();
1014+
p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
1015+
p.set(KEY.INPUT_DIMENSIONS, new int[] { 200 });
1016+
p.set(KEY.RANDOM, new MersenneTwister(42));
1017+
1018+
Network network = Network.create("test network", p)
1019+
.add(Network.createRegion("r1")
1020+
.add(Network.createLayer("2", p)
1021+
.add(Anomaly.create())
1022+
.add(new TemporalMemory())
1023+
.add(new SpatialPooler())
1024+
.close()));
1025+
1026+
Region r1 = network.lookup("r1");
1027+
Layer<?> layer2 = r1.lookup("2");
1028+
1029+
int width = layer2.calculateInputWidth();
1030+
assertEquals(200, width);
1031+
assertEquals(200, layer2.getConnections().getPotentialRadius());
1032+
}
1033+
10101034
///////////////////////////////////////////////////////////////////////////////////
10111035
// Tests of Calculate Input Width for inter-regional and inter-layer calcs //
10121036
///////////////////////////////////////////////////////////////////////////////////
@@ -1063,7 +1087,6 @@ public void testCalculateInputWidth_NoPrevLayer_UpstreamRegion_without_TM() {
10631087

10641088
int width = layer2.calculateInputWidth();
10651089
assertEquals(2048, width);
1066-
10671090
}
10681091

10691092
@Test
@@ -1077,7 +1100,6 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andTM() {
10771100
.add(Network.createLayer("2", p)
10781101
.add(Anomaly.create())
10791102
.add(new TemporalMemory())
1080-
//.add(new SpatialPooler())
10811103
.close()));
10821104

10831105
Region r1 = network.lookup("r1");
@@ -1098,14 +1120,15 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andSPTM() {
10981120
.add(Network.createLayer("2", p)
10991121
.add(Anomaly.create())
11001122
.add(new TemporalMemory())
1101-
.add(new SpatialPooler())
1123+
.add(new SpatialPooler())
11021124
.close()));
11031125

11041126
Region r1 = network.lookup("r1");
11051127
Layer<?> layer2 = r1.lookup("2");
11061128

11071129
int width = layer2.calculateInputWidth();
11081130
assertEquals(8, width);
1131+
assertEquals(8, layer2.getConnections().getPotentialRadius());
11091132
}
11101133

11111134
@Test
@@ -1126,6 +1149,7 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andNoTM() {
11261149

11271150
int width = layer2.calculateInputWidth();
11281151
assertEquals(8, width);
1152+
assertEquals(8, layer2.getConnections().getPotentialRadius());
11291153
}
11301154

11311155
@Test

src/test/java/org/numenta/nupic/network/NetworkTestHarness.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ public static Parameters getParameters() {
218218
parameters.set(KEY.CELLS_PER_COLUMN, 6);
219219

220220
//SpatialPooler specific
221-
parameters.set(KEY.POTENTIAL_RADIUS, 12);//3
221+
parameters.set(KEY.POTENTIAL_RADIUS, -1);//3
222222
parameters.set(KEY.POTENTIAL_PCT, 0.5);//0.5
223223
parameters.set(KEY.GLOBAL_INHIBITION, false);
224224
parameters.set(KEY.LOCAL_AREA_DENSITY, -1.0);

0 commit comments

Comments
 (0)