Skip to content

Commit

Permalink
added warning check to bernoulli, some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 24, 2024
1 parent 05ea912 commit c19d15e
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions ngclearn/components/input_encoders/bernoulliCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from jax import numpy as jnp, random, jit
from ngclearn.utils import tensorstats
from functools import partial
from ngcsimlib.deprecators import deprecate_args
from ngcsimlib.logger import info, warn

@jit
def _update_times(t, s, tols):
Expand Down Expand Up @@ -79,15 +81,15 @@ class BernoulliCell(JaxComponent):
n_units: number of cellular entities (neural population size)
max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
"""

# Define Functions
def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
@deprecate_args(target_freq="max_freq")
def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs):
super().__init__(name, **kwargs)

## Constrained Bernoulli meta-parameters
self.max_freq = max_freq ## maximum frequency (in Hertz/Hz)
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)

## Layer Size Setup
self.batch_size = batch_size
Expand All @@ -99,12 +101,26 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike

def validate(self, dt, **validation_kwargs):
## check for unstable combinations of dt and target-frequency meta-params
valid = super().validate(**validation_kwargs)
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
if events_per_timestep > 1.:
valid = False
warn(
f"{self.name} will be unable to make as many temporal events as "
f"requested! ({events_per_timestep} events/timestep) Unstable "
f"combination of dt = {dt} and target_freq = {self.target_freq} "
f"being used!"
)
return valid

@staticmethod
def _advance_state(t, dt, max_freq, key, inputs, tols):
def _advance_state(t, dt, target_freq, key, inputs, tols):
key, *subkeys = random.split(key, 2)
if max_freq > 0.:
if target_freq > 0.:
outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate
subkeys[0], data=inputs, dt=dt, fmax=max_freq
subkeys[0], data=inputs, dt=dt, fmax=target_freq
)
else:
outputs = _sample_bernoulli(subkeys[0], data=inputs)
Expand Down

0 comments on commit c19d15e

Please sign in to comment.