3
3
from jax import numpy as jnp , random , jit
4
4
from ngclearn .utils import tensorstats
5
5
from functools import partial
6
+ from ngcsimlib .deprecators import deprecate_args
7
+ from ngcsimlib .logger import info , warn
6
8
7
9
@jit
8
10
def _update_times (t , s , tols ):
@@ -79,15 +81,15 @@ class BernoulliCell(JaxComponent):
79
81
80
82
n_units: number of cellular entities (neural population size)
81
83
82
- max_freq : maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
84
+ target_freq : maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
83
85
"""
84
86
85
- # Define Functions
86
- def __init__ (self , name , n_units , max_freq = 63.75 , batch_size = 1 , ** kwargs ):
87
+ @ deprecate_args ( target_freq = "max_freq" )
88
+ def __init__ (self , name , n_units , target_freq = 63.75 , batch_size = 1 , ** kwargs ):
87
89
super ().__init__ (name , ** kwargs )
88
90
89
91
## Constrained Bernoulli meta-parameters
90
- self .max_freq = max_freq ## maximum frequency (in Hertz/Hz)
92
+ self .target_freq = target_freq ## maximum frequency (in Hertz/Hz)
91
93
92
94
## Layer Size Setup
93
95
self .batch_size = batch_size
@@ -99,12 +101,26 @@ def __init__(self, name, n_units, max_freq=63.75, batch_size=1, **kwargs):
99
101
self .outputs = Compartment (restVals , display_name = "Spikes" ) # output compartment
100
102
self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
101
103
104
+ def validate (self , dt , ** validation_kwargs ):
105
+ ## check for unstable combinations of dt and target-frequency meta-params
106
+ valid = super ().validate (** validation_kwargs )
107
+ events_per_timestep = (dt / 1000. ) * self .target_freq ## compute scaled probability
108
+ if events_per_timestep > 1. :
109
+ valid = False
110
+ warn (
111
+ f"{ self .name } will be unable to make as many temporal events as "
112
+ f"requested! ({ events_per_timestep } events/timestep) Unstable "
113
+ f"combination of dt = { dt } and target_freq = { self .target_freq } "
114
+ f"being used!"
115
+ )
116
+ return valid
117
+
102
118
@staticmethod
103
- def _advance_state (t , dt , max_freq , key , inputs , tols ):
119
+ def _advance_state (t , dt , target_freq , key , inputs , tols ):
104
120
key , * subkeys = random .split (key , 2 )
105
- if max_freq > 0. :
121
+ if target_freq > 0. :
106
122
outputs = _sample_constrained_bernoulli ( ## sample Bernoulli w/ target rate
107
- subkeys [0 ], data = inputs , dt = dt , fmax = max_freq
123
+ subkeys [0 ], data = inputs , dt = dt , fmax = target_freq
108
124
)
109
125
else :
110
126
outputs = _sample_bernoulli (subkeys [0 ], data = inputs )
0 commit comments