Skip to content

Commit

Permalink
fixed validation fun in bern/poiss
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 25, 2024
1 parent 223d3c0 commit 9afaadf
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
7 changes: 5 additions & 2 deletions ngclearn/components/input_encoders/bernoulliCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,12 @@ def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike

def validate(self, dt, **validation_kwargs):
## check for unstable combinations of dt and target-frequency meta-params
def validate(self, dt=None, **validation_kwargs):
valid = super().validate(**validation_kwargs)
if dt is None:
warn(f"{self.name} requires a validation kwarg of `dt`")
return False
## check for unstable combinations of dt and target-frequency meta-params
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
if events_per_timestep > 1.:
valid = False
Expand Down
9 changes: 6 additions & 3 deletions ngclearn/components/input_encoders/poissonCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
random.uniform(subkey, (self.batch_size, self.n_units), minval=0.,
maxval=1.))

def validate(self, dt, **validation_kwargs):
## check for unstable combinations of dt and target-frequency meta-params
def validate(self, dt=None, **validation_kwargs):
valid = super().validate(**validation_kwargs)
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
if dt is None:
warn(f"{self.name} requires a validation kwarg of `dt`")
return False
## check for unstable combinations of dt and target-frequency meta-params
events_per_timestep = (dt / 1000.) * self.target_freq ## compute scaled probability
if events_per_timestep > 1.:
valid = False
warn(
Expand Down

0 comments on commit 9afaadf

Please sign in to comment.