From 9afaadfe65df2bd2597fb49d75666efc58eb3c31 Mon Sep 17 00:00:00 2001 From: ago109 Date: Thu, 25 Jul 2024 11:55:08 -0400 Subject: [PATCH] fixed validation fun in bern/poiss --- ngclearn/components/input_encoders/bernoulliCell.py | 7 +++++-- ngclearn/components/input_encoders/poissonCell.py | 9 ++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index ea3b26cc8..b061b431f 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -101,9 +101,12 @@ def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs): self.outputs = Compartment(restVals, display_name="Spikes") # output compartment self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike - def validate(self, dt, **validation_kwargs): - ## check for unstable combinations of dt and target-frequency meta-params + def validate(self, dt=None, **validation_kwargs): valid = super().validate(**validation_kwargs) + if dt is None: + warn(f"{self.name} requires a validation kwarg of `dt`") + return False + ## check for unstable combinations of dt and target-frequency meta-params events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability if events_per_timestep > 1.: valid = False diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index 36b968a82..8c5fba991 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -56,10 +56,13 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1, random.uniform(subkey, (self.batch_size, self.n_units), minval=0., maxval=1.)) - def validate(self, dt, **validation_kwargs): - ## check for unstable combinations of dt and target-frequency meta-params + def validate(self, dt=None, **validation_kwargs): valid = super().validate(**validation_kwargs) - events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability + if dt is None: + warn(f"{self.name} requires a validation kwarg of `dt`") + return False + ## check for unstable combinations of dt and target-frequency meta-params + events_per_timestep = (dt / 1000.) * self.target_freq ## compute scaled probability if events_per_timestep > 1.: valid = False warn(