Skip to content

Commit

Permalink
updated lif-cell to use units/tags and minor cleanup and edits
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 23, 2024
1 parent 23eb7ed commit 1ddd86d
Showing 1 changed file with 71 additions and 95 deletions.
166 changes: 71 additions & 95 deletions ngclearn/components/neurons/spiking/LIFCell.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from jax import numpy as jnp, random, jit, nn
from functools import partial
from ngclearn.utils import tensorstats
from ngcsimlib.deprecators import deprecate_args
from ngclearn import resolver, Component, Compartment
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import secant_lif_estimator, arctan_estimator, triangular_estimator
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
triangular_estimator,
straight_through_estimator)

@jit
def _update_times(t, s, tols):
Expand All @@ -25,12 +28,6 @@ def _update_times(t, s, tols):
_tols = (1. - s) * tols + (s * t)
return _tols

# @jit
# def _modify_current(j, dt, tau_m, R_m):
# ## electrical current re-scaling co-routine
# jScale = tau_m/dt ## <-- this anti-scale counter-balances form of ODE used in this cell
# return (j * R_m) * jScale

@jit
def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
Expand All @@ -47,44 +44,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
#@partial(jit, static_argnums=[7, 8, 9, 10, 11, 12])
def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, tau_m, v_rest, v_reset,
v_decay, refract_T, integType=0):
"""
Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics.
Args:
dt: integration time constant (milliseconds, or ms)
j: electrical current value
v: membrane potential (voltage, in milliVolts or mV) value (at t)
v_thr: base voltage threshold value (in mV)
v_theta: threshold shift (homeostatic) variable (at t)
rfr: refractory variable vector (one per neuronal cell)
skey: PRNG key which, if not None, will trigger a single-spike constraint
(i.e., only one spike permitted to emit per single step of time);
specifically used to randomly sample one of the possible action
potentials to be an emitted spike
tau_m: cell membrane time constant
v_rest: membrane resting potential (in mV)
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
a neuronal cell's membrane potential will be set to this value
v_decay: strength of voltage leak (Default: 1.)
refract_T: (relative) refractory time period (in ms; Default
value is 1 ms)
integType: integer indicating type of integration to use
Returns:
voltage(t+dt), spikes, raw spikes, updated refactory variables
"""
### Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics.
_v_thr = v_theta + v_thr ## calc present voltage threshold
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
## update voltage / membrane potential
Expand Down Expand Up @@ -114,24 +74,7 @@ def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, tau_m, v_rest, v_reset,

@partial(jit, static_argnums=[3, 4])
def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
"""
Runs homeostatic threshold update dynamics one step (via Euler integration).
Args:
dt: integration time constant (milliseconds, or ms)
v_theta: current value of homeostatic threshold variable
s: current spikes (at t)
tau_theta: homeostatic threshold time constant
theta_plus: physical increment to be applied to any threshold value if
a spike was emitted
Returns:
updated homeostatic threshold variable
"""
### Runs homeostatic threshold update dynamics one step (via Euler integration).
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
#theta_plus = 0.05
#_V_theta = V_theta * theta_decay + S * theta_plus
Expand Down Expand Up @@ -205,11 +148,10 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
at an increase in computational cost (and simulation time)
"""

# Define Functions
@deprecate_args(thr_jitter=None)
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
v_reset=-60., v_decay=1., tau_theta=1e7, theta_plus=0.05,
refract_time=5., thr_jitter=0., one_spike=False,
integration_type="euler", **kwargs):
refract_time=5., one_spike=False, integration_type="euler", **kwargs):
super().__init__(name, **kwargs)

## Integration properties
Expand All @@ -218,15 +160,15 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,

## membrane parameter setup (affects ODE integration)
self.tau_m = tau_m ## membrane time constant
self.R_m = resist_m ## resistance value
self.resist_m = resist_m ## resistance value
self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step

self.v_rest = v_rest #-65. # mV
self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
self.v_decay = v_decay ## controls strength of voltage leak (1 -> LIF, 0 => IF)
## basic asserts to prevent neuronal dynamics breaking...
#assert (self.v_decay * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify...
assert self.R_m > 0.
assert self.resist_m > 0.
self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off)
self.theta_plus = theta_plus #0.05 ## threshold increment
self.refract_T = refract_time #5. # 2. ## refractory period # ms
Expand All @@ -237,41 +179,45 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.n_units = n_units

## set up surrogate function for spike emission
self.spike_fx, self.d_spike_fx = secant_lif_estimator()
#self.spike_fx, self.d_spike_fx = arctan_estimator() #
#self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator()
surrgoate_type = "secant_lif"
if surrgoate_type == "secant_lif":
self.spike_fx, self.d_spike_fx = secant_lif_estimator()
elif surrgoate_type == "arctan":
self.spike_fx, self.d_spike_fx = arctan_estimator()
elif surrgoate_type == "triangular":
self.spike_fx, self.d_spike_fx = triangular_estimator()
else: ## default is the straight-through estimator (STE)
self.spike_fx, self.d_spike_fx = straight_through_estimator()


## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
thr0 = 0.
if thr_jitter > 0.:
key, subkey = random.split(self.key.value)
thr0 = random.uniform(subkey, (1, n_units), minval=-thr_jitter,
maxval=thr_jitter, dtype=jnp.float32)
self.j = Compartment(restVals)
self.v = Compartment(restVals + self.v_rest)
self.s = Compartment(restVals)
self.s_raw = Compartment(restVals)
self.rfr = Compartment(restVals + self.refract_T)
self.thr_theta = Compartment(restVals + thr0)
self.tols = Compartment(restVals) ## time-of-last-spike
self.surrogate = Compartment(restVals + 1.) ## surrogate signal
self.j = Compartment(restVals, display_name="Current", units="mA")
self.v = Compartment(restVals + self.v_rest,
display_name="Voltage", units="mV")
self.s = Compartment(restVals, display_name="Spikes")
self.s_raw = Compartment(restVals, display_name="Raw Spike Pulses")
self.rfr = Compartment(restVals + self.refract_T,
display_name="Refractory Time Period", units="ms")
self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift",
units="mV")
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
units="ms") ## time-of-last-spike
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")

@staticmethod
def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T,
thr, tau_theta, theta_plus, one_spike, intgFlag, d_spike_fx,
key, j, v, s, rfr, thr_theta, tols):
skey = None ## this is an empty dkey if single_spike mode turned off
if one_spike:
key, skey = random.split(key, 2)
## run one integration step for neuronal dynamics
j = j * R_m
#surrogate = d_spike_fx(v, thr + thr_theta)
j = j * resist_m
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
tau_m, v_rest, v_reset, v_decay, refract_T,
intgFlag)
tau_m, v_rest, v_reset, v_decay,
refract_T, intgFlag)
surrogate = d_spike_fx(v, thr + thr_theta)
#surrogate = d_spike_fx(j, thr + thr_theta)
if tau_theta > 0.:
## run one integration step for threshold dynamics
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
Expand Down Expand Up @@ -310,22 +256,53 @@ def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
self.s.set(s)
self.s_raw.set(s_raw)
self.rfr.set(rfr)
#self.thr_theta.set(thr_theta)
self.tols.set(tols)
self.surrogate.set(surrogate)

def save(self, directory, **kwargs):
## do a protected save of constants, depending on whether they are floats or arrays
tau_m = (self.tau_m if isinstance(self.tau_m, float)
else jnp.ones([[self.tau_m]]))
thr = (self.thr if isinstance(self.thr, float)
else jnp.ones([[self.thr]]))
v_rest = (self.v_rest if isinstance(self.v_rest, float)
else jnp.ones([[self.v_rest]]))
v_reset = (self.v_reset if isinstance(self.v_reset, float)
else jnp.ones([[self.v_reset]]))
v_decay = (self.v_decay if isinstance(self.v_decay, float)
else jnp.ones([[self.v_decay]]))
resist_m = (self.resist_m if isinstance(self.resist_m, float)
else jnp.ones([[self.resist_m]]))
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
else jnp.ones([[self.tau_theta]]))
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
else jnp.ones([[self.theta_plus]]))

file_name = directory + "/" + self.name + ".npz"
jnp.savez(file_name,
threshold_theta=self.thr_theta.value,
tau_m=tau_m, thr=thr, v_rest=v_rest,
v_reset=v_reset, v_decay=v_decay,
resist_m=resist_m, tau_theta=tau_theta,
theta_plus=theta_plus,
key=self.key.value)

def load(self, directory, seeded=False, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
self.thr_theta.set( data['threshold_theta'] )
if seeded == True:
self.key.set( data['key'] )
self.thr_theta.set(data['thr_theta'])
## constants loaded in
self.tau_m = data['tau_m']
self.thr = data['thr']
self.v_rest = data['v_rest']
self.v_reset = data['v_reset']
self.v_decay = data['v_decay']
self.resist_m = data['resist_m']
self.tau_theta = data['tau_theta']
self.theta_plus = data['theta_plus']

if seeded:
self.key.set(data['key'])

@classmethod
def help(cls): ## component help function
Expand Down Expand Up @@ -356,7 +333,6 @@ def help(cls): ## component help function
"tau_theta": "Threshold/homoestatic increment time constant",
"theta_plus": "Amount to increment threshold by upon occurrence of spike",
"refract_time": "Length of relative refractory period (ms)",
"thr_jitter": "Scale of random uniform noise to apply to initial condition of threshold",
"one_spike": "Should only one spike be sampled/allowed to emit at any given time step?",
"integration_type": "Type of numerical integration to use for the cell dynamics"
}
Expand Down

0 comments on commit 1ddd86d

Please sign in to comment.