Skip to content

Commit c19d15e

Browse files
committed
added warning check to bernoulli, some cleanup
1 parent 05ea912 commit c19d15e

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from jax import numpy as jnp, random, jit
44
from ngclearn.utils import tensorstats
55
from functools import partial
6+
from ngcsimlib.deprecators import deprecate_args
7+
from ngcsimlib.logger import info, warn
68

79
@jit
810
def _update_times(t, s, tols):
@@ -79,15 +81,15 @@ class BernoulliCell(JaxComponent):
7981
8082
n_units: number of cellular entities (neural population size)
8183
82-
max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
84+
target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
8385
"""
8486

85-
# Define Functions
86-
def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
87+
@deprecate_args(target_freq="max_freq")
88+
def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs):
8789
super().__init__(name, **kwargs)
8890

8991
## Constrained Bernoulli meta-parameters
90-
self.max_freq = max_freq ## maximum frequency (in Hertz/Hz)
92+
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)
9193

9294
## Layer Size Setup
9395
self.batch_size = batch_size
@@ -99,12 +101,26 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
99101
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
100102
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
101103

104+
def validate(self, dt, **validation_kwargs):
105+
## check for unstable combinations of dt and target-frequency meta-params
106+
valid = super().validate(**validation_kwargs)
107+
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
108+
if events_per_timestep > 1.:
109+
valid = False
110+
warn(
111+
f"{self.name} will be unable to make as many temporal events as "
112+
f"requested! ({events_per_timestep} events/timestep) Unstable "
113+
f"combination of dt = {dt} and target_freq = {self.target_freq} "
114+
f"being used!"
115+
)
116+
return valid
117+
102118
@staticmethod
103-
def _advance_state(t, dt, max_freq, key, inputs, tols):
119+
def _advance_state(t, dt, target_freq, key, inputs, tols):
104120
key, *subkeys = random.split(key, 2)
105-
if max_freq > 0.:
121+
if target_freq > 0.:
106122
outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate
107-
subkeys[0], data=inputs, dt=dt, fmax=max_freq
123+
subkeys[0], data=inputs, dt=dt, fmax=target_freq
108124
)
109125
else:
110126
outputs = _sample_bernoulli(subkeys[0], data=inputs)

0 commit comments

Comments
 (0)