From c19d15ebf5f0922e9e807652845cda8685c92911 Mon Sep 17 00:00:00 2001 From: ago109 Date: Wed, 24 Jul 2024 14:41:36 -0400 Subject: [PATCH] added warning check to bernoulli, some cleanup --- .../input_encoders/bernoulliCell.py | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 61e740316..fa036302b 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -3,6 +3,8 @@ from jax import numpy as jnp, random, jit from ngclearn.utils import tensorstats from functools import partial +from ngcsimlib.deprecators import deprecate_args +from ngcsimlib.logger import info, warn @jit def _update_times(t, s, tols): @@ -79,15 +81,15 @@ class BernoulliCell(JaxComponent): n_units: number of cellular entities (neural population size) - max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.) + target_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.) """ - # Define Functions - def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs): + @deprecate_args(target_freq="max_freq") + def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs): super().__init__(name, **kwargs) ## Constrained Bernoulli meta-parameters - self.max_freq = max_freq ## maximum frequency (in Hertz/Hz) + self.target_freq = target_freq ## maximum frequency (in Hertz/Hz) ## Layer Size Setup self.batch_size = batch_size @@ -99,12 +101,26 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs): self.outputs = Compartment(restVals, display_name="Spikes") # output compartment self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike + def validate(self, dt, **validation_kwargs): + ## check for unstable combinations of dt and target-frequency meta-params + valid = super().validate(**validation_kwargs) + events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability + if events_per_timestep > 1.: + valid = False + warn( + f"{self.name} will be unable to make as many temporal events as " + f"requested! ({events_per_timestep} events/timestep) Unstable " + f"combination of dt = {dt} and target_freq = {self.target_freq} " + f"being used!" + ) + return valid + @staticmethod - def _advance_state(t, dt, max_freq, key, inputs, tols): + def _advance_state(t, dt, target_freq, key, inputs, tols): key, *subkeys = random.split(key, 2) - if max_freq > 0.: + if target_freq > 0.: outputs = _sample_constrained_bernoulli( ## sample Bernoulli w/ target rate - subkeys[0], data=inputs, dt=dt, fmax=max_freq + subkeys[0], data=inputs, dt=dt, fmax=target_freq ) else: outputs = _sample_bernoulli(subkeys[0], data=inputs)