From 1ddd86d85bc348043c374bdd580b30d3fa5e5802 Mon Sep 17 00:00:00 2001 From: ago109 Date: Tue, 23 Jul 2024 14:53:47 -0400 Subject: [PATCH] updated lif-cell to use units/tags and minor cleanup and edits --- .../components/neurons/spiking/LIFCell.py | 166 ++++++++---------- 1 file changed, 71 insertions(+), 95 deletions(-) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index c754e208e..f44a93187 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -1,11 +1,14 @@ from jax import numpy as jnp, random, jit, nn from functools import partial from ngclearn.utils import tensorstats +from ngcsimlib.deprecators import deprecate_args from ngclearn import resolver, Component, Compartment from ngclearn.components.jaxComponent import JaxComponent from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngclearn.utils.surrogate_fx import secant_lif_estimator, arctan_estimator, triangular_estimator +from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, + triangular_estimator, + straight_through_estimator) @jit def _update_times(t, s, tols): @@ -25,12 +28,6 @@ def _update_times(t, s, tols): _tols = (1. - s) * tols + (s * t) return _tols -# @jit -# def _modify_current(j, dt, tau_m, R_m): -# ## electrical current re-scaling co-routine -# jScale = tau_m/dt ## <-- this anti-scale counter-balances form of ODE used in this cell -# return (j * R_m) * jScale - @jit def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask @@ -47,44 +44,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper #@partial(jit, static_argnums=[7, 8, 9, 10, 11, 12]) def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, tau_m, v_rest, v_reset, v_decay, refract_T, integType=0): - """ - Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics. - - Args: - dt: integration time constant (milliseconds, or ms) - - j: electrical current value - - v: membrane potential (voltage, in milliVolts or mV) value (at t) - - v_thr: base voltage threshold value (in mV) - - v_theta: threshold shift (homeostatic) variable (at t) - - rfr: refractory variable vector (one per neuronal cell) - - skey: PRNG key which, if not None, will trigger a single-spike constraint - (i.e., only one spike permitted to emit per single step of time); - specifically used to randomly sample one of the possible action - potentials to be an emitted spike - - tau_m: cell membrane time constant - - v_rest: membrane resting potential (in mV) - - v_reset: membrane reset potential (in mV) -- upon occurrence of a spike, - a neuronal cell's membrane potential will be set to this value - - v_decay: strength of voltage leak (Default: 1.) - - refract_T: (relative) refractory time period (in ms; Default - value is 1 ms) - - integType: integer indicating type of integration to use - - Returns: - voltage(t+dt), spikes, raw spikes, updated refactory variables - """ + ### Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics. _v_thr = v_theta + v_thr ## calc present voltage threshold #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask ## update voltage / membrane potential @@ -114,24 +74,7 @@ def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, tau_m, v_rest, v_reset, @partial(jit, static_argnums=[3, 4]) def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05): - """ - Runs homeostatic threshold update dynamics one step (via Euler integration). - - Args: - dt: integration time constant (milliseconds, or ms) - - v_theta: current value of homeostatic threshold variable - - s: current spikes (at t) - - tau_theta: homeostatic threshold time constant - - theta_plus: physical increment to be applied to any threshold value if - a spike was emitted - - Returns: - updated homeostatic threshold variable - """ + ### Runs homeostatic threshold update dynamics one step (via Euler integration). #theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7) #theta_plus = 0.05 #_V_theta = V_theta * theta_decay + S * theta_plus @@ -205,11 +148,10 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell at an increase in computational cost (and simulation time) """ - # Define Functions + @deprecate_args(thr_jitter=None) def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., v_decay=1., tau_theta=1e7, theta_plus=0.05, - refract_time=5., thr_jitter=0., one_spike=False, - integration_type="euler", **kwargs): + refract_time=5., one_spike=False, integration_type="euler", **kwargs): super().__init__(name, **kwargs) ## Integration properties @@ -218,7 +160,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., ## membrane parameter setup (affects ODE integration) self.tau_m = tau_m ## membrane time constant - self.R_m = resist_m ## resistance value + self.resist_m = resist_m ## resistance value self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step self.v_rest = v_rest #-65. # mV @@ -226,7 +168,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.v_decay = v_decay ## controls strength of voltage leak (1 -> LIF, 0 => IF) ## basic asserts to prevent neuronal dynamics breaking... #assert (self.v_decay * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify... - assert self.R_m > 0. + assert self.resist_m > 0. self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off) self.theta_plus = theta_plus #0.05 ## threshold increment self.refract_T = refract_time #5. # 2. ## refractory period # ms @@ -237,41 +179,45 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.n_units = n_units ## set up surrogate function for spike emission - self.spike_fx, self.d_spike_fx = secant_lif_estimator() - #self.spike_fx, self.d_spike_fx = arctan_estimator() # - #self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator() + surrgoate_type = "secant_lif" + if surrgoate_type == "secant_lif": + self.spike_fx, self.d_spike_fx = secant_lif_estimator() + elif surrgoate_type == "arctan": + self.spike_fx, self.d_spike_fx = arctan_estimator() + elif surrgoate_type == "triangular": + self.spike_fx, self.d_spike_fx = triangular_estimator() + else: ## default is the straight-through estimator (STE) + self.spike_fx, self.d_spike_fx = straight_through_estimator() + ## Compartment setup restVals = jnp.zeros((self.batch_size, self.n_units)) - thr0 = 0. - if thr_jitter > 0.: - key, subkey = random.split(self.key.value) - thr0 = random.uniform(subkey, (1, n_units), minval=-thr_jitter, - maxval=thr_jitter, dtype=jnp.float32) - self.j = Compartment(restVals) - self.v = Compartment(restVals + self.v_rest) - self.s = Compartment(restVals) - self.s_raw = Compartment(restVals) - self.rfr = Compartment(restVals + self.refract_T) - self.thr_theta = Compartment(restVals + thr0) - self.tols = Compartment(restVals) ## time-of-last-spike - self.surrogate = Compartment(restVals + 1.) ## surrogate signal + self.j = Compartment(restVals, display_name="Current", units="mA") + self.v = Compartment(restVals + self.v_rest, + display_name="Voltage", units="mV") + self.s = Compartment(restVals, display_name="Spikes") + self.s_raw = Compartment(restVals, display_name="Raw Spike Pulses") + self.rfr = Compartment(restVals + self.refract_T, + display_name="Refractory Time Period", units="ms") + self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift", + units="mV") + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", + units="ms") ## time-of-last-spike + self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") @staticmethod - def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, + def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta, theta_plus, one_spike, intgFlag, d_spike_fx, key, j, v, s, rfr, thr_theta, tols): skey = None ## this is an empty dkey if single_spike mode turned off if one_spike: key, skey = random.split(key, 2) ## run one integration step for neuronal dynamics - j = j * R_m - #surrogate = d_spike_fx(v, thr + thr_theta) + j = j * resist_m v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey, - tau_m, v_rest, v_reset, v_decay, refract_T, - intgFlag) + tau_m, v_rest, v_reset, v_decay, + refract_T, intgFlag) surrogate = d_spike_fx(v, thr + thr_theta) - #surrogate = d_spike_fx(j, thr + thr_theta) if tau_theta > 0.: ## run one integration step for threshold dynamics thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus) @@ -310,22 +256,53 @@ def reset(self, j, v, s, s_raw, rfr, tols, surrogate): self.s.set(s) self.s_raw.set(s_raw) self.rfr.set(rfr) - #self.thr_theta.set(thr_theta) self.tols.set(tols) self.surrogate.set(surrogate) def save(self, directory, **kwargs): + ## do a protected save of constants, depending on whether they are floats or arrays + tau_m = (self.tau_m if isinstance(self.tau_m, float) + else jnp.ones([[self.tau_m]])) + thr = (self.thr if isinstance(self.thr, float) + else jnp.ones([[self.thr]])) + v_rest = (self.v_rest if isinstance(self.v_rest, float) + else jnp.ones([[self.v_rest]])) + v_reset = (self.v_reset if isinstance(self.v_reset, float) + else jnp.ones([[self.v_reset]])) + v_decay = (self.v_decay if isinstance(self.v_decay, float) + else jnp.ones([[self.v_decay]])) + resist_m = (self.resist_m if isinstance(self.resist_m, float) + else jnp.ones([[self.resist_m]])) + tau_theta = (self.tau_theta if isinstance(self.tau_theta, float) + else jnp.ones([[self.tau_theta]])) + theta_plus = (self.theta_plus if isinstance(self.theta_plus, float) + else jnp.ones([[self.theta_plus]])) + file_name = directory + "/" + self.name + ".npz" jnp.savez(file_name, threshold_theta=self.thr_theta.value, + tau_m=tau_m, thr=thr, v_rest=v_rest, + v_reset=v_reset, v_decay=v_decay, + resist_m=resist_m, tau_theta=tau_theta, + theta_plus=theta_plus, key=self.key.value) def load(self, directory, seeded=False, **kwargs): file_name = directory + "/" + self.name + ".npz" data = jnp.load(file_name) - self.thr_theta.set( data['threshold_theta'] ) - if seeded == True: - self.key.set( data['key'] ) + self.thr_theta.set(data['thr_theta']) + ## constants loaded in + self.tau_m = data['tau_m'] + self.thr = data['thr'] + self.v_rest = data['v_rest'] + self.v_reset = data['v_reset'] + self.v_decay = data['v_decay'] + self.resist_m = data['resist_m'] + self.tau_theta = data['tau_theta'] + self.theta_plus = data['theta_plus'] + + if seeded: + self.key.set(data['key']) @classmethod def help(cls): ## component help function @@ -356,7 +333,6 @@ def help(cls): ## component help function "tau_theta": "Threshold/homoestatic increment time constant", "theta_plus": "Amount to increment threshold by upon occurrence of spike", "refract_time": "Length of relative refractory period (ms)", - "thr_jitter": "Scale of random uniform noise to apply to initial condition of threshold", "one_spike": "Should only one spike be sampled/allowed to emit at any given time step?", "integration_type": "Type of numerical integration to use for the cell dynamics" }