Skip to content

Commit 16588fa

Browse files
committed
added spike-reset/snap-back to fn-cell
1 parent e6d7317 commit 16588fa

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

ngclearn/components/neurons/spiking/fitzhughNagumoCell.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,17 @@ class FitzhughNagumoCell(JaxComponent):
113113
114114
gamma: power-term divisor (Default: 3.)
115115
116-
v_thr: voltage/membrane threshold (to obtain action potentials in terms
117-
of binary spikes)
118-
119116
v0: initial condition / reset for voltage
120117
121118
w0: initial condition / reset for recovery
122119
120+
v_thr: voltage/membrane threshold (to obtain action potentials in terms
121+
of binary spikes)
122+
123+
spike_reset: if True, once voltage crosses threshold, then dynamics
124+
of voltage and recovery are reset/snapped to initial conditions
125+
(default: False)
126+
123127
integration_type: type of integration to use for this cell's dynamics;
124128
current supported forms include "euler" (Euler/RK-1 integration)
125129
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
@@ -131,7 +135,7 @@ class FitzhughNagumoCell(JaxComponent):
131135

132136
# Define Functions
133137
def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
134-
beta=0.8, gamma=3., v_thr=1.07, v0=0., w0=0.,
138+
beta=0.8, gamma=3., v0=0., w0=0., v_thr=1.07, spike_reset=False,
135139
integration_type="euler", **kwargs):
136140
super().__init__(name, **kwargs)
137141

@@ -150,6 +154,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
150154
self.v0 = v0 ## initial membrane potential/voltage condition
151155
self.w0 = w0 ## initial w-parameter condition
152156
self.v_thr = v_thr
157+
self.spike_reset = spike_reset
153158

154159
## Layer Size Setup
155160
self.batch_size = 1
@@ -164,10 +169,13 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
164169
self.tols = Compartment(restVals) ## time-of-last-spike
165170

166171
@staticmethod
167-
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, alpha, beta, gamma,
168-
intgFlag, j, v, w, tols):
172+
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha,
173+
beta, gamma, intgFlag, j, v, w, tols):
169174
v, w, s = _run_cell(dt, j * R_m, v, w, v_thr, tau_m, tau_w, alpha, beta,
170175
gamma, intgFlag)
176+
if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
177+
v = v * (1. - s) + s * v0
178+
w = w * (1. - s) + s * w0
171179
tols = _update_times(t, s, tols)
172180
return j, v, w, s, tols
173181

@@ -220,6 +228,8 @@ def help(cls): ## component help function
220228
"resist_m": "Membrane resistance value",
221229
"tau_w": "Recovery variable time constant",
222230
"v_thr": "Base voltage threshold value",
231+
"spike_reset": "Should voltage/recover be snapped to initial "
232+
"condition(s) if spike emitted?",
223233
"alpha": "Dimensionless recovery variable shift factor `a",
224234
"beta": "Dimensionless recovery variable scale factor `b`",
225235
"gamma": "Power-term divisor constant",

0 commit comments

Comments
 (0)