22from ngclearn .components .jaxComponent import JaxComponent
33from jax import numpy as jnp , random , jit
44from ngclearn .utils import tensorstats
5+ from functools import partial
56
67@jit
78def _update_times (t , s , tols ):
@@ -37,9 +38,33 @@ def _sample_bernoulli(dkey, data):
3738 s_t = random .bernoulli (dkey , p = data ).astype (jnp .float32 )
3839 return s_t
3940
41+ @partial (jit , static_argnums = [3 ])
42+ def _sample_constrained_bernoulli (dkey , data , dt , fmax = 63.75 ):
43+ """
44+ Samples a Bernoulli spike train on-the-fly that is constrained to emit
45+ at a particular rate over a time window.
46+
47+ Args:
48+ dkey: JAX key to drive stochasticity/noise
49+
50+ data: sensory data (vector/matrix)
51+
52+ dt: integration time constant
53+
54+ fmax: maximum frequency (Hz)
55+
56+ Returns:
57+ binary spikes
58+ """
59+ pspike = data * (dt / 1000. ) * fmax
60+ eps = random .uniform (dkey , data .shape , minval = 0. , maxval = 1. , dtype = jnp .float32 )
61+ s_t = (eps < pspike ).astype (jnp .float32 )
62+ return s_t
63+
4064class BernoulliCell (JaxComponent ):
4165 """
42- A Bernoulli cell that produces Bernoulli-distributed spikes on-the-fly.
66+ A Bernoulli cell that produces variations of Bernoulli-distributed spikes
67+ on-the-fly (including constrained-rate trains).
4368
4469 | --- Cell Input Compartments: ---
4570 | inputs - input (takes in external signals)
@@ -53,12 +78,17 @@ class BernoulliCell(JaxComponent):
5378 name: the string name of this cell
5479
5580 n_units: number of cellular entities (neural population size)
81+
82+ max_freq: maximum frequency (in Hertz) of this Bernoulli spike train (must be > 0.)
5683 """
5784
5885 # Define Functions
59- def __init__ (self , name , n_units , batch_size = 1 , ** kwargs ):
86+ def __init__ (self , name , n_units , max_freq = 63.75 , batch_size = 1 , ** kwargs ):
6087 super ().__init__ (name , ** kwargs )
6188
89+ ## Constrained Bernoulli meta-parameters
90+ self .max_freq = max_freq ## maximum frequency (in Hertz/Hz)
91+
6292 ## Layer Size Setup
6393 self .batch_size = batch_size
6494 self .n_units = n_units
@@ -70,11 +100,16 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
70100 self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" , units = "ms" ) # time of last spike
71101
72102 @staticmethod
73- def _advance_state (t , key , inputs , tols ):
103+ def _advance_state (t , dt , max_freq , key , inputs , tols ):
74104 key , * subkeys = random .split (key , 2 )
75- outputs = _sample_bernoulli (subkeys [0 ], data = inputs )
76- timeOfLastSpike = _update_times (t , outputs , tols )
77- return outputs , timeOfLastSpike , key
105+ if max_freq > 0. :
106+ outputs = _sample_constrained_bernoulli ( ## sample Bernoulli w/ target rate
107+ subkeys [0 ], data = inputs , dt = dt , fmax = max_freq
108+ )
109+ else :
110+ outputs = _sample_bernoulli (subkeys [0 ], data = inputs )
111+ tols = _update_times (t , outputs , tols )
112+ return outputs , tols , key
78113
79114 @resolver (_advance_state )
80115 def advance_state (self , outputs , tols , key ):
0 commit comments