Skip to content

Commit

Permalink
Added auto resolving for monitors
Browse files Browse the repository at this point in the history
  • Loading branch information
willgebhardt committed Jul 23, 2024
1 parent 19bf502 commit 6bb493b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 33 deletions.
80 changes: 47 additions & 33 deletions ngclearn/components/base_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class Base_Monitor(Component):
default_window_length: The default window length.
"""
auto_resolve = False

_singleton = None # Only one Monitor

@staticmethod
def build_advance(compartments):
Expand All @@ -62,14 +62,19 @@ def build_advance(compartments):
"ngclearn.components.lava (If using lava)")

@staticmethod
def build_reset(compartments):
def build_reset(component):
"""
A method to build the method to reset the stored values.
Args:
compartments: A list of compartments to reset
component: The component to resolve
Returns: The method to reset the stored values.
Returns: the reset resolver
"""
output_compartments = []
compartments = []
for comp in component.compartments:
output_compartments.append(comp.split("/")[-1] + "*store")
compartments.append(comp.split("/")[-1])

@staticmethod
def _reset(**kwargs):
Expand All @@ -79,13 +84,22 @@ def _reset(**kwargs):
return_vals.append(np.zeros(current_store.shape))
return return_vals if len(compartments) > 1 else return_vals[0]

return _reset
# pure func, output compartments, args, params, input compartments
return _reset, output_compartments, [], [], output_compartments

@staticmethod
def build_advance_state(component):
output_compartments = []
compartments = []
for comp in component.compartments:
output_compartments.append(comp.split("/")[-1] + "*store")
compartments.append(comp.split("/")[-1])

_advance = component.build_advance(compartments)

return _advance, output_compartments, [], [], compartments + output_compartments

def __init__(self, name, default_window_length=100, **kwargs):
if Base_Monitor._singleton is not None:
critical("Only one monitor can be built")
else:
Base_Monitor._singleton = True
super().__init__(name, **kwargs)
self.store = {}
self.compartments = []
Expand Down Expand Up @@ -127,7 +141,7 @@ def watch(self, compartment, window_length):
setattr(self, store_comp_key, new_comp_store)
self.compartments.append(new_comp.path)
self._sources.append(compartment)
self._update_resolver()
# self._update_resolver()

def halt(self, compartment):
"""
Expand Down Expand Up @@ -157,29 +171,29 @@ def halt_all(self):
for compartment in self._sources:
self.halt(compartment)

def _update_resolver(self):
output_compartments = []
compartments = []
for comp in self.compartments:
output_compartments.append(comp.split("/")[-1] + "*store")
compartments.append(comp.split("/")[-1])

args = []
parameters = []

add_component_resolver(self.__class__.__name__, "advance_state",
(self.build_advance(compartments),
output_compartments))
add_resolver_meta(self.__class__.__name__, "advance_state",
(args, parameters,
compartments + [o for o in output_compartments],
False))

add_component_resolver(self.__class__.__name__, "reset", (
self.build_reset(compartments), output_compartments))
add_resolver_meta(self.__class__.__name__, "reset",
(args, parameters, [o for o in output_compartments],
False))
# def _update_resolver(self):
# output_compartments = []
# compartments = []
# for comp in self.compartments:
# output_compartments.append(comp.split("/")[-1] + "*store")
# compartments.append(comp.split("/")[-1])
#
# args = []
# parameters = []
#
# add_component_resolver(self.__class__.__name__, "advance_state",
# (self.build_advance(compartments),
# output_compartments))
# add_resolver_meta(self.__class__.__name__, "advance_state",
# (args, parameters,
# compartments + [o for o in output_compartments],
# False))

# add_component_resolver(self.__class__.__name__, "reset", (
# self.build_reset(compartments), output_compartments))
# add_resolver_meta(self.__class__.__name__, "reset",
# (args, parameters, [o for o in output_compartments],
# False))

def _add_path(self, path):
_path = path.split("/")[1:]
Expand Down
10 changes: 10 additions & 0 deletions ngclearn/components/lava/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ class Monitor(Base_Monitor):
"""
A numpy implementation of `Base_Monitor`. Designed to be used with all lava compatible ngclearn components
"""
auto_resolve = False


@staticmethod
def build_advance(compartments):
Expand All @@ -20,3 +22,11 @@ def _advance(**kwargs):
return return_vals if len(compartments) > 1 else return_vals[0]

return _advance

@staticmethod
def build_advance_state(component):
return super().build_advance_state(component)

@staticmethod
def build_reset(component):
return super().build_reset(component)
10 changes: 10 additions & 0 deletions ngclearn/components/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ class Monitor(Base_Monitor):
A jax implementation of `Base_Monitor`. Designed to be used with all
non-lava ngclearn components
"""
auto_resolve = False

@staticmethod
def build_advance(compartments):
@staticmethod
Expand All @@ -18,3 +20,11 @@ def _advance(**kwargs):
return_vals.append(current_store)
return return_vals if len(compartments) > 1 else return_vals[0]
return _advance

@staticmethod
def build_advance_state(component):
return super().build_advance_state(component)

@staticmethod
def build_reset(component):
return super().build_reset(component)

0 comments on commit 6bb493b

Please sign in to comment.