@@ -113,13 +113,17 @@ class FitzhughNagumoCell(JaxComponent):
113
113
114
114
gamma: power-term divisor (Default: 3.)
115
115
116
- v_thr: voltage/membrane threshold (to obtain action potentials in terms
117
- of binary spikes)
118
-
119
116
v0: initial condition / reset for voltage
120
117
121
118
w0: initial condition / reset for recovery
122
119
120
+ v_thr: voltage/membrane threshold (to obtain action potentials in terms
121
+ of binary spikes)
122
+
123
+ spike_reset: if True, once voltage crosses threshold, then dynamics
124
+ of voltage and recovery are reset/snapped to initial conditions
125
+ (default: False)
126
+
123
127
integration_type: type of integration to use for this cell's dynamics;
124
128
current supported forms include "euler" (Euler/RK-1 integration)
125
129
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
@@ -131,7 +135,7 @@ class FitzhughNagumoCell(JaxComponent):
131
135
132
136
# Define Functions
133
137
def __init__ (self , name , n_units , tau_m = 1. , resist_m = 1. , tau_w = 12.5 , alpha = 0.7 ,
134
- beta = 0.8 , gamma = 3. , v_thr = 1.07 , v0 = 0. , w0 = 0. ,
138
+ beta = 0.8 , gamma = 3. , v0 = 0. , w0 = 0. , v_thr = 1.07 , spike_reset = False ,
135
139
integration_type = "euler" , ** kwargs ):
136
140
super ().__init__ (name , ** kwargs )
137
141
@@ -150,6 +154,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
150
154
self .v0 = v0 ## initial membrane potential/voltage condition
151
155
self .w0 = w0 ## initial w-parameter condition
152
156
self .v_thr = v_thr
157
+ self .spike_reset = spike_reset
153
158
154
159
## Layer Size Setup
155
160
self .batch_size = 1
@@ -164,10 +169,13 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
164
169
self .tols = Compartment (restVals ) ## time-of-last-spike
165
170
166
171
@staticmethod
167
- def _advance_state (t , dt , tau_m , R_m , tau_w , v_thr , alpha , beta , gamma ,
168
- intgFlag , j , v , w , tols ):
172
+ def _advance_state (t , dt , tau_m , R_m , tau_w , v_thr , spike_reset , v0 , w0 , alpha ,
173
+ beta , gamma , intgFlag , j , v , w , tols ):
169
174
v , w , s = _run_cell (dt , j * R_m , v , w , v_thr , tau_m , tau_w , alpha , beta ,
170
175
gamma , intgFlag )
176
+ if spike_reset : ## if spike-reset used, variables snapped back to initial conditions
177
+ v = v * (1. - s ) + s * v0
178
+ w = w * (1. - s ) + s * w0
171
179
tols = _update_times (t , s , tols )
172
180
return j , v , w , s , tols
173
181
@@ -220,6 +228,8 @@ def help(cls): ## component help function
220
228
"resist_m" : "Membrane resistance value" ,
221
229
"tau_w" : "Recovery variable time constant" ,
222
230
"v_thr" : "Base voltage threshold value" ,
231
+ "spike_reset" : "Should voltage/recover be snapped to initial "
232
+ "condition(s) if spike emitted?" ,
223
233
"alpha" : "Dimensionless recovery variable shift factor `a" ,
224
234
"beta" : "Dimensionless recovery variable scale factor `b`" ,
225
235
"gamma" : "Power-term divisor constant" ,
0 commit comments