diff --git a/ngclearn/components/base_monitor.py b/ngclearn/components/base_monitor.py index 6efb367cc..00c338b2a 100644 --- a/ngclearn/components/base_monitor.py +++ b/ngclearn/components/base_monitor.py @@ -42,8 +42,8 @@ class Base_Monitor(Component): default_window_length: The default window length. """ + auto_resolve = False - _singleton = None # Only one Monitor @staticmethod def build_advance(compartments): @@ -62,14 +62,19 @@ def build_advance(compartments): "ngclearn.components.lava (If using lava)") @staticmethod - def build_reset(compartments): + def build_reset(component): """ A method to build the method to reset the stored values. Args: - compartments: A list of compartments to reset + component: The component to resolve - Returns: The method to reset the stored values. + Returns: the reset resolver """ + output_compartments = [] + compartments = [] + for comp in component.compartments: + output_compartments.append(comp.split("/")[-1] + "*store") + compartments.append(comp.split("/")[-1]) @staticmethod def _reset(**kwargs): @@ -79,13 +84,22 @@ def _reset(**kwargs): return_vals.append(np.zeros(current_store.shape)) return return_vals if len(compartments) > 1 else return_vals[0] - return _reset + # pure func, output compartments, args, params, input compartments + return _reset, output_compartments, [], [], output_compartments + + @staticmethod + def build_advance_state(component): + output_compartments = [] + compartments = [] + for comp in component.compartments: + output_compartments.append(comp.split("/")[-1] + "*store") + compartments.append(comp.split("/")[-1]) + + _advance = component.build_advance(compartments) + + return _advance, output_compartments, [], [], compartments + output_compartments def __init__(self, name, default_window_length=100, **kwargs): - if Base_Monitor._singleton is not None: - critical("Only one monitor can be built") - else: - Base_Monitor._singleton = True super().__init__(name, **kwargs) self.store = {} self.compartments = [] @@ -127,7 +141,7 @@ def watch(self, compartment, window_length): setattr(self, store_comp_key, new_comp_store) self.compartments.append(new_comp.path) self._sources.append(compartment) - self._update_resolver() + # self._update_resolver() def halt(self, compartment): """ @@ -157,29 +171,29 @@ def halt_all(self): for compartment in self._sources: self.halt(compartment) - def _update_resolver(self): - output_compartments = [] - compartments = [] - for comp in self.compartments: - output_compartments.append(comp.split("/")[-1] + "*store") - compartments.append(comp.split("/")[-1]) - - args = [] - parameters = [] - - add_component_resolver(self.__class__.__name__, "advance_state", - (self.build_advance(compartments), - output_compartments)) - add_resolver_meta(self.__class__.__name__, "advance_state", - (args, parameters, - compartments + [o for o in output_compartments], - False)) - - add_component_resolver(self.__class__.__name__, "reset", ( - self.build_reset(compartments), output_compartments)) - add_resolver_meta(self.__class__.__name__, "reset", - (args, parameters, [o for o in output_compartments], - False)) + # def _update_resolver(self): + # output_compartments = [] + # compartments = [] + # for comp in self.compartments: + # output_compartments.append(comp.split("/")[-1] + "*store") + # compartments.append(comp.split("/")[-1]) + # + # args = [] + # parameters = [] + # + # add_component_resolver(self.__class__.__name__, "advance_state", + # (self.build_advance(compartments), + # output_compartments)) + # add_resolver_meta(self.__class__.__name__, "advance_state", + # (args, parameters, + # compartments + [o for o in output_compartments], + # False)) + + # add_component_resolver(self.__class__.__name__, "reset", ( + # self.build_reset(compartments), output_compartments)) + # add_resolver_meta(self.__class__.__name__, "reset", + # (args, parameters, [o for o in output_compartments], + # False)) def _add_path(self, path): _path = path.split("/")[1:] diff --git a/ngclearn/components/lava/monitor.py b/ngclearn/components/lava/monitor.py index de939fb06..aaabf8f84 100644 --- a/ngclearn/components/lava/monitor.py +++ b/ngclearn/components/lava/monitor.py @@ -5,6 +5,8 @@ class Monitor(Base_Monitor): """ A numpy implementation of `Base_Monitor`. Designed to be used with all lava compatible ngclearn components """ + auto_resolve = False + @staticmethod def build_advance(compartments): @@ -20,3 +22,11 @@ def _advance(**kwargs): return return_vals if len(compartments) > 1 else return_vals[0] return _advance + + @staticmethod + def build_advance_state(component): + return super().build_advance_state(component) + + @staticmethod + def build_reset(component): + return super().build_reset(component) diff --git a/ngclearn/components/monitor.py b/ngclearn/components/monitor.py index 6eada571d..4d61c3438 100644 --- a/ngclearn/components/monitor.py +++ b/ngclearn/components/monitor.py @@ -5,6 +5,8 @@ class Monitor(Base_Monitor): A jax implementation of `Base_Monitor`. Designed to be used with all non-lava ngclearn components """ + auto_resolve = False + @staticmethod def build_advance(compartments): @staticmethod @@ -18,3 +20,11 @@ def _advance(**kwargs): return_vals.append(current_store) return return_vals if len(compartments) > 1 else return_vals[0] return _advance + + @staticmethod + def build_advance_state(component): + return super().build_advance_state(component) + + @staticmethod + def build_reset(component): + return super().build_reset(component)