33from jax import numpy as jnp , random , jit
44from ngclearn .utils import tensorstats
55from functools import partial
6+ from ngcsimlib .deprecators import deprecate_args
7+ from ngcsimlib .logger import info , warn
68
79@jit
810def _update_times (t , s , tols ):
@@ -79,15 +81,15 @@ class BernoulliCell(JaxComponent):
7981
8082 n_units: number of cellular entities (neural population size)
8183
82- max_freq : maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
84+ target_freq : maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
8385 """
8486
85- # Define Functions
86- def __init__ (self , name , n_units , max_freq = 63.75 , batch_size = 1 , ** kwargs ):
87+ @ deprecate_args ( target_freq = "max_freq" )
88+ def __init__ (self , name , n_units , target_freq = 63.75 , batch_size = 1 , ** kwargs ):
8789 super ().__init__ (name , ** kwargs )
8890
8991 ## Constrained Bernoulli meta-parameters
90- self .max_freq = max_freq ## maximum frequency (in Hertz/Hz)
92+ self .target_freq = target_freq ## maximum frequency (in Hertz/Hz)
9193
9294 ## Layer Size Setup
9395 self .batch_size = batch_size
@@ -99,12 +101,26 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
99101 self .outputs = Compartment (restVals , display_name = "Spikes" ) # output compartment
100102 self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
101103
104+ def validate (self , dt , ** validation_kwargs ):
105+ ## check for unstable combinations of dt and target-frequency meta-params
106+ valid = super ().validate (** validation_kwargs )
107+ events_per_timestep = (dt / 1000. ) * self .target_freq ## compute scaled probability
108+ if events_per_timestep > 1. :
109+ valid = False
110+ warn (
111+ f"{ self .name } will be unable to make as many temporal events as "
112+ f"requested! ({ events_per_timestep } events/timestep) Unstable "
113+ f"combination of dt = { dt } and target_freq = { self .target_freq } "
114+ f"being used!"
115+ )
116+ return valid
117+
102118 @staticmethod
103- def _advance_state (t , dt , max_freq , key , inputs , tols ):
119+ def _advance_state (t , dt , target_freq , key , inputs , tols ):
104120 key , * subkeys = random .split (key , 2 )
105- if max_freq > 0. :
121+ if target_freq > 0. :
106122 outputs = _sample_constrained_bernoulli ( ## sample Bernoulli w/ target rate
107- subkeys [0 ], data = inputs , dt = dt , fmax = max_freq
123+ subkeys [0 ], data = inputs , dt = dt , fmax = target_freq
108124 )
109125 else :
110126 outputs = _sample_bernoulli (subkeys [0 ], data = inputs )
0 commit comments