Skip to content

Commit

Permalink
added spike-reset/snap-back to fn-cell
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jun 28, 2024
1 parent e6d7317 commit 16588fa
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions ngclearn/components/neurons/spiking/fitzhughNagumoCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,17 @@ class FitzhughNagumoCell(JaxComponent):
gamma: power-term divisor (Default: 3.)
v_thr: voltage/membrane threshold (to obtain action potentials in terms
of binary spikes)
v0: initial condition / reset for voltage
w0: initial condition / reset for recovery
v_thr: voltage/membrane threshold (to obtain action potentials in terms
of binary spikes)
spike_reset: if True, once voltage crosses threshold, then dynamics
of voltage and recovery are reset/snapped to initial conditions
(default: False)
integration_type: type of integration to use for this cell's dynamics;
current supported forms include "euler" (Euler/RK-1 integration)
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
Expand All @@ -131,7 +135,7 @@ class FitzhughNagumoCell(JaxComponent):

# Define Functions
def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
beta=0.8, gamma=3., v_thr=1.07, v0=0., w0=0.,
beta=0.8, gamma=3., v0=0., w0=0., v_thr=1.07, spike_reset=False,
integration_type="euler", **kwargs):
super().__init__(name, **kwargs)

Expand All @@ -150,6 +154,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
self.v0 = v0 ## initial membrane potential/voltage condition
self.w0 = w0 ## initial w-parameter condition
self.v_thr = v_thr
self.spike_reset = spike_reset

## Layer Size Setup
self.batch_size = 1
Expand All @@ -164,10 +169,13 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
self.tols = Compartment(restVals) ## time-of-last-spike

@staticmethod
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, alpha, beta, gamma,
intgFlag, j, v, w, tols):
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha,
beta, gamma, intgFlag, j, v, w, tols):
v, w, s = _run_cell(dt, j * R_m, v, w, v_thr, tau_m, tau_w, alpha, beta,
gamma, intgFlag)
if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
v = v * (1. - s) + s * v0
w = w * (1. - s) + s * w0
tols = _update_times(t, s, tols)
return j, v, w, s, tols

Expand Down Expand Up @@ -220,6 +228,8 @@ def help(cls): ## component help function
"resist_m": "Membrane resistance value",
"tau_w": "Recovery variable time constant",
"v_thr": "Base voltage threshold value",
"spike_reset": "Should voltage/recover be snapped to initial "
"condition(s) if spike emitted?",
"alpha": "Dimensionless recovery variable shift factor `a",
"beta": "Dimensionless recovery variable scale factor `b`",
"gamma": "Power-term divisor constant",
Expand Down

0 comments on commit 16588fa

Please sign in to comment.