@@ -151,7 +151,8 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
151151 @deprecate_args (thr_jitter = None )
152152 def __init__ (self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. ,
153153 v_reset = - 60. , v_decay = 1. , tau_theta = 1e7 , theta_plus = 0.05 ,
154- refract_time = 5. , one_spike = False , integration_type = "euler" , ** kwargs ):
154+ refract_time = 5. , one_spike = False , integration_type = "euler" ,
155+ surrgoate_type = "straight_through" , ** kwargs ):
155156 super ().__init__ (name , ** kwargs )
156157
157158 ## Integration properties
@@ -179,14 +180,13 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
179180 self .n_units = n_units
180181
181182 ## set up surrogate function for spike emission
182- surrgoate_type = "secant_lif"
183183 if surrgoate_type == "secant_lif" :
184184 self .spike_fx , self .d_spike_fx = secant_lif_estimator ()
185185 elif surrgoate_type == "arctan" :
186186 self .spike_fx , self .d_spike_fx = arctan_estimator ()
187187 elif surrgoate_type == "triangular" :
188188 self .spike_fx , self .d_spike_fx = triangular_estimator ()
189- else : ## default is the straight-through estimator (STE)
189+ else : ## default: straight_through
190190 self .spike_fx , self .d_spike_fx = straight_through_estimator ()
191191
192192
0 commit comments