diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index ca345fb79..d871add89 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -119,6 +119,7 @@ def _create_node_with_operator( from __future__ import annotations +from collections.abc import ItemsView, KeysView, ValuesView import operator # Operator overloading for computation nodes. from weakref import WeakSet # Manages relationships between nodes without # creating circular dependencies. @@ -152,19 +153,19 @@ def _create_node_with_operator( class DeepTrackDataObject: """Basic data container for DeepTrack2. - `DeepTrackDataObject` is a simple data container to store some data and + `DeepTrackDataObject` is a simple data container to store some data and track its validity. Attributes ---------- - data: Any - The stored data. Default is `None`. - valid: bool - A flag indicating whether the stored data is valid. Default is `False`. + _data: Any + The stored data. Defaults to `None`. + _valid: bool + Flag indicating whether the stored data is valid. Defaults to `False`. Methods ------- - `store(data: Any) -> None` + `store(data) -> None` Store data in the container and mark it as valid. `current_value() -> Any` Return the currently stored data. @@ -174,53 +175,60 @@ class DeepTrackDataObject: Mark the data as invalid. `validate() -> None` Mark the data as valid. + `__repr__() -> str` + Return the string representation of the object. 
Example ------- >>> import deeptrack as dt Create a `DeepTrackDataObject`: - >>> data_obj = dt.DeepTrackDataObject() + >>> data_obj Store a value in this container: - >>> data_obj.store(42) + >>> data_obj + DeepTrackDataObject(data=42, valid=True) + + Access the currently stored value: >>> data_obj.current_value() 42 Check if the stored data is valid: - >>> data_obj.is_valid() True Invalidate the stored data: - >>> data_obj.invalidate() + >>> data_obj + DeepTrackDataObject(data=42, valid=False) + >>> data_obj.is_valid() False - Validate the data again to restore its valid status: - + Validate the data to restore its valid status: >>> data_obj.validate() + >>> data_obj + DeepTrackDataObject(data=42, valid=True) + >>> data_obj.is_valid() True """ - data: Any - valid: bool + _data: Any + _valid: bool def __init__(self: DeepTrackDataObject): """Initialize the container without data. - It sets the `data` and `valid` attributes to their default values - `None` and `False`. + Initializes `_data` to `None` and `_valid` to `False`. """ - self.data = None - self.valid = False + self._data = None + self._valid = False def store( self: DeepTrackDataObject, @@ -235,8 +243,8 @@ def store( """ - self.data = data - self.valid = True + self._data = data + self._valid = True def current_value(self: DeepTrackDataObject) -> Any: """Retrieve the stored data. @@ -248,7 +256,7 @@ def current_value(self: DeepTrackDataObject) -> Any: """ - return self.data + return self._data def is_valid(self: DeepTrackDataObject) -> bool: """Return whether the stored data is valid. @@ -257,183 +265,253 @@ def is_valid(self: DeepTrackDataObject) -> bool: ------- bool `True` if the data is valid, `False` otherwise. 
- + """ - return self.valid + return self._valid def invalidate(self: DeepTrackDataObject) -> None: """Mark the stored data as invalid.""" - self.valid = False + self._valid = False def validate(self: DeepTrackDataObject) -> None: """Mark the stored data as valid.""" - self.valid = True + self._valid = True + + def __repr__(self: DeepTrackDataObject) -> str: + """Return the string representation of the object. + + Provides a concise representation of the data object, including the + stored data and its validity flag. It is useful for debugging and + logging purposes. + + Returns + ------- + str + A string in the format: + "DeepTrackDataObject(data=, valid=)". + + """ + + return ( + f"{self.__class__.__name__}" + f"(data={self._data!r}, valid={self._valid})" + ) class DeepTrackDataDict: - """Stores multiple data objects indexed by tuples of integers (_ID). + """Store multiple data objects indexed by tuples of integers (_ID). + + `DeepTrackDataDict` can store multiple `DeepTrackDataObject` instances, + each associated with a unique tuple of integers (its `_ID`). + + **Use of _IDs** + + The default `_ID` is an empty tuple, `_ID = ()`. - `DeepTrackDataDict` can store multiple `DeepTrackDataObject` instances, - each associated with a unique tuple of integers (its _ID). This is - particularly useful to handle sequences of data or nested structures. + Once the first entry is created, all `_ID`s must match the set key length. - The default _ID is an empty tuple, `()`. Once the first entry is created, - all _IDs must match the established key length: - - If an _ID longer than the set length is requested, it is trimmed. - - If an _ID shorter than the set length is requested, a dictionary slice - containing all matching entries is returned. + When retrieving the data associated to an `_ID`: + - If an `_ID` longer than the set key length is requested, it is trimmed. 
+ - If an `_ID` shorter than the set key length is requested, a dictionary + slice containing all matching entries is returned. + + NOTE: The `_ID`s are specifically used in the `Repeat` feature to allow it + to return different values without changing the input. Attributes ---------- keylength: int or None - The length of the _IDs currently stored. Set when the first entry is - created. If `None`, no entries have been created yet, and any _ID - length is valid. + Read-only property exposing the internal variable with the length of + the `_ID`s set when the first entry is created. If `None`, no entries + have been created, and any `_ID` length is valid. dict: dict[tuple[int, ...], DeepTrackDataObject] or {} - A dictionary mapping tuples of integers (_IDs) to + Read-only property exposing the internal dictionary of stored data, + `_dict`. This is a dictionary mapping tuples of integers (`_ID`s) to `DeepTrackDataObject` instances. Methods ------- + `create_index(_ID) -> None` + Create an entry for the given `_ID` if it does not exist. `invalidate() -> None` Mark all stored data objects as invalid. `validate() -> None` Mark all stored data objects as valid. - `valid_index(_ID: tuple[int, ...]) -> bool` + `valid_index(_ID) -> bool` Check if the given _ID is valid for the current configuration. - `create_index(_ID: tuple[int, ...] = ()) -> None` - Create an entry for the given _ID if it does not exist. - `__getitem__(_ID: tuple[int, ...]) -> DeepTrackDataObject or dict[tuple[int, ...], DeepTrackDataObject]` - Retrieve data associated with the _ID. Can return a - `DeepTrackDataObject` or a dict of matching entries if `_ID` is shorter - than `keylength`. - `__contains__(_ID: tuple[int, ...]) -> bool` - Check whether the given _ID exists in the dictionary. + `__getitem__(_ID) -> DeepTrackDataObject or dict[_ID, DeepTrackDataObject]` + Retrieve data associated with the `_ID`. 
Can return a + `DeepTrackDataObject`, or a dict of `DeepTrackDataObject`s if `_ID` is + shorter than `keylength`. + `__contains__(_ID) -> bool` + Check whether the given `_ID` exists in the dictionary. + `__len__() -> int` + Return the number of stored entries. + `__iter__() -> Iterator` + Iterate over the keys of the dictionary. + `items() -> ItemsView[tuple[int, ...], DeepTrackDataObject]` + Return a view of the dictionary’s (key, value) pairs. + `keys() -> KeysView[tuple[int, ...]]` + Return a view of the dictionary’s keys. + `values() -> ValuesView[DeepTrackDataObject]` + Return a view of the dictionary’s values. + `__repr__() -> str` + Return a string representation of the data dictionary. Example ------- >>> import deeptrack as dt Create a structure to store multiple, indexed instances of data: - >>> data_dict = dt.DeepTrackDataDict() + >>> data_dict + DeepTrackDataDict(0 entries, keylength=None) Create the entries: - >>> data_dict.create_index((0, 0)) >>> data_dict.create_index((0, 1)) >>> data_dict.create_index((1, 0)) >>> data_dict.create_index((1, 1)) + >>> data_dict + DeepTrackDataDict(4 entries, keylength=2) - Store the values associated with each _ID: - + Store the values associated with each `_ID`: >>> data_dict[(0, 0)].store("Data at (0, 0)") >>> data_dict[(0, 1)].store("Data at (0, 1)") >>> data_dict[(1, 0)].store("Data at (1, 0)") >>> data_dict[(1, 1)].store("Data at (1, 1)") + >>> data_dict + DeepTrackDataDict(4 entries, keylength=2) - Retrieve values based on their _IDs: + Retrieve values based on their `_ID`s: + >>> data_dict[(0, 0)] + DeepTrackDataObject(data='Data at (0, 0)', valid=True) >>> data_dict[(0, 0)].current_value() - Data at (0, 0) + 'Data at (0, 0)' + + >>> data_dict[(1, 1)] + DeepTrackDataObject(data='Data at (1, 1)', valid=True) >>> data_dict[(1, 1)].current_value() - Data at (1, 1) + 'Data at (1, 1)' - If requesting a shorter _ID, it returns all matching nested entries: - + If requesting a shorter `_ID`, it returns all matching 
nested entries: >>> data_dict[(0,)] - { - (0, 0): , - (0, 1): , - } + {(0, 0): DeepTrackDataObject(data='Data at (0, 0)', valid=True), + (0, 1): DeepTrackDataObject(data='Data at (0, 1)', valid=True)} Validate and invalidate all entries at once: - >>> data_dict.invalidate() >>> data_dict[(0, 0)].is_valid() False + >>> data_dict[(1, 1)].is_valid() False >>> data_dict.validate() >>> data_dict[(0, 0)].is_valid() True + >>> data_dict[(1, 1)].is_valid() True Invalidate and validate a single entry: - >>> data_dict[(0, 1)].invalidate() >>> data_dict[(0, 1)].is_valid() False + >>> data_dict[(0, 1)].validate() >>> data_dict[(0, 1)].is_valid() True - Check if a given _ID exists: - + Check if a given `_ID` exists: >>> (1, 0) in data_dict True + >>> (2, 2) in data_dict False Iterate over all entries: - - >>> for key, value in data_dict.dict.items(): + >>> for key, value in data_dict.items(): ... print(key, value.current_value()) - (0, 0) Data at (0, 0) - (0, 1) Data at (0, 1) - (1, 0) Data at (1, 0) - (1, 1) Data at (1, 1) - - Check if an _ID is valid according to current keylength: - + (0, 0) DeepTrackDataObject(data='Data at (0, 0)', valid=True) + (0, 1) DeepTrackDataObject(data='Data at (0, 1)', valid=True) + (1, 0) DeepTrackDataObject(data='Data at (1, 0)', valid=True) + (1, 1) DeepTrackDataObject(data='Data at (1, 1)', valid=True) + + >>> for key in data_dict.keys(): + ... print(key) + (0, 0) + (0, 1) + (1, 0) + (1, 1) + + >>> for value in data_dict.values(): + ... 
print(value) + DeepTrackDataObject(data='Data at (0, 0)', valid=True) + DeepTrackDataObject(data='Data at (0, 1)', valid=True) + DeepTrackDataObject(data='Data at (1, 0)', valid=True) + DeepTrackDataObject(data='Data at (1, 1)', valid=True) + + Check if an `_ID` is valid according to current keylength: >>> data_dict.valid_index((0, 1)) True - >>> data_dict.valid_index((0,)) # Shorter than keylength after creation + + >>> data_dict.valid_index((0,)) # Shorter than keylength + False + + >>> data_dict.valid_index((0, 1, 2)) # Longer than keylength False + >>> data_dict.valid_index((2, 2)) # Valid length, even if not created yet True """ - keylength: int | None - dict: dict[tuple[int, ...], DeepTrackDataObject] + _keylength: int | None + _dict: dict[tuple[int, ...], DeepTrackDataObject] def __init__(self: DeepTrackDataDict): """Initialize the data dictionary. - It initializes `keylength` to `None` and `dict` to an empty dictionary, - indicating no data objects are currently stored. - + Initializes `keylength` to `None` and `dict` to an empty dictionary, + indicating no `DeepTrackDataObject`s are currently stored. + """ - self.keylength = None - self.dict = {} + self._keylength = None + self._dict = {} def invalidate(self: DeepTrackDataDict) -> None: """Mark all stored data objects as invalid. - It calls `invalidate()` on every `DeepTrackDataObject` in the - dictionary. + Calls `invalidate()` on every `DeepTrackDataObject` in the dictionary. + + NOTE: Currently, it invalidates the data objects stored at all `_ID`s. + TODO: Add optional argument `_ID: tuple[int, ...] = ()` and permit + invalidation of only specific `_ID`s. """ - for dataobject in self.dict.values(): + for dataobject in self._dict.values(): dataobject.invalidate() def validate(self: DeepTrackDataDict) -> None: """Mark all stored data objects as valid. - It calls `validate()` on every `DeepTrackDataObject` in the dictionary. 
+ + NOTE: Currently, it validates the data objects stored at all `_ID`s. + TODO: Add optional argument `_ID: tuple[int, ...] = ()` and permit + validation of only specific `_ID`s. """ - for dataobject in self.dict.values(): + for dataobject in self._dict.values(): dataobject.validate() def valid_index( @@ -442,14 +520,14 @@ def valid_index( ) -> bool: """Check if a given _ID is valid for this data dictionary. - If `keylength` is `None`, any tuple `_ID` is considered valid since no - entries have been created yet. + If `keylength` is `None`, any tuple `_ID` is considered valid (since + no entries have been created yet). If `_ID` already exists in `dict`, it is automatically valid. - + Otherwise, `_ID` must have the same length as `keylength` to be considered valid. - + Parameters ---------- _ID: tuple[int, ...] The _ID to check. Returns ------- bool - `True` if the _ID is valid given the current configuration, `False` - otherwise. + `True` if the `_ID` is valid given the current configuration, + `False` otherwise. Raises ------ AssertionError If `_ID` is not a tuple of integers. - + """ # Ensure _ID is a tuple of integers. assert isinstance(_ID, tuple), ( f"Data index {_ID} is not a tuple. Got: {type(_ID).__name__}." ) assert all(isinstance(i, int) for i in _ID), ( f"Data index {_ID} is not a tuple of integers. " f"Got a tuple of types: {[type(i).__name__ for i in _ID]}." ) # If keylength has not yet been set, all indexes are valid. - if self.keylength is None: + if self._keylength is None: return True # If index is already stored, always valid. - if _ID in self.dict: + if _ID in self._dict: return True # Otherwise, the _ID length must match the established keylength # for _ID to be valid. - return len(_ID) == self.keylength + return len(_ID) == self._keylength def create_index( self: DeepTrackDataDict, @@ -499,23 +577,24 @@ def create_index( `DeepTrackDataObject`. If `_ID` is already in `dict`, no new entry is created. - - If `keylength` is `None`, it is set to the length of `_ID`. Once - established, all subsequently created _IDs must have this same length. + + If `keylength` is `None`, it is set to the length of `_ID`. 
Once + established, all subsequently created `_ID`s must have this same + length. Parameters ---------- _ID: tuple[int, ...], optional - A tuple of integers representing the _ID for the data entry. - Default is `()`, which represents a root-level data entry with no + A tuple of integers representing the _ID for the data entry. + Defaults to `()`, which represents a root-level data entry with no nesting. - + Raises ------ AssertionError - If `_ID` is not a tuple of integers. - If `_ID` is not valid for the current configuration. - + """ # Check if the given _ID is valid. @@ -525,15 +604,15 @@ def create_index( ) # If `_ID` already exists, do nothing. - if _ID in self.dict: + if _ID in self._dict: return # Create a new DeepTrackDataObject for this _ID. - self.dict[_ID] = DeepTrackDataObject() + self._dict[_ID] = DeepTrackDataObject() - # If `keylength` is not set, initialize it with current _IDs length. - if self.keylength is None: - self.keylength = len(_ID) + # If `_keylength` is not set, initialize it with current _IDs length. + if self._keylength is None: + self._keylength = len(_ID) def __getitem__( self: DeepTrackDataDict, @@ -544,25 +623,26 @@ def __getitem__( Parameters ---------- _ID: tuple[int, ...] - The _ID for the requested data. + The `_ID` for the requested data. Returns ------- - DeepTrackDataObject or Dict[tuple[int, ...], DeepTrackDataObject] - If `_ID` matches `keylength`, it returns the corresponding + DeepTrackDataObject or dict[tuple[int, ...], DeepTrackDataObject] + If `_ID` matches `keylength`, it returns the corresponding `DeepTrackDataObject`. - If `_ID` is longer than `keylength`, the request is trimmed to + If `_ID` is longer than `keylength`, the request is trimmed to match `keylength` and it returns the corresponding `DeepTrackDataObject`. If `_ID` is shorter than `keylength`, it returns a dict of all - entries whose _IDs match the given `_ID` prefix. + entries whose `_ID`s match the given `_ID` prefix. 
Raises ------ AssertionError If `_ID` is not a tuple of integers. KeyError - If the dictionary is empty (`keylength` is `None`). + If the dictionary is empty (`keylength` is `None`), or if the + requested `_ID` is not in the dictionary. """ @@ -575,20 +655,25 @@ def __getitem__( f"Got a tuple of types: {[type(i).__name__ for i in _ID]}." ) - if self.keylength is None: + if self._keylength is None: raise KeyError("Attempting to index an empty dict.") # If _ID matches keylength, return corresponding DeepTrackDataObject. - if len(_ID) == self.keylength: - return self.dict[_ID] + if len(_ID) == self._keylength: + if _ID not in self._dict: + raise KeyError( + f"The _ID {_ID} does not exist in this DeepTrackDataDict. " + f"Available keys: {list(self._dict.keys())}" + ) + return self._dict[_ID] # If _ID longer than keylength, trim the requested _ID # and return corresponding DeepTrackDataObject. - if len(_ID) > self.keylength: - return self[_ID[: self.keylength]] + if len(_ID) > self._keylength: + return self[_ID[: self._keylength]] # If _ID shorter than keylength, return a slice of all matching items. - return {k: v for k, v in self.dict.items() if k[: len(_ID)] == _ID} + return {k: v for k, v in self._dict.items() if k[: len(_ID)] == _ID} def __contains__( self: DeepTrackDataDict, @@ -599,38 +684,140 @@ def __contains__( Parameters ---------- _ID: tuple[int, ...] - The _ID to check. + The `_ID` to check. Returns ------- bool - `True` if the _ID exists, `False` otherwise. + `True` if `_ID` exists, `False` otherwise. - Raises - ------ - AssertionError - If `_ID` is not a tuple of integers. + """ + + return _ID in self._dict + + def __len__(self: DeepTrackDataDict) -> int: + """Return the number of stored entries. + + Returns + ------- + int + The number of `_ID` entries in the dictionary. """ - # Ensure _ID is a tuple of integers. - assert isinstance(_ID, tuple), ( - f"Data index {_ID} is not a tuple. Got: {type(_ID).__name__}." 
- ) - assert all(isinstance(i, int) for i in _ID), ( - f"Data index {_ID} is not a tuple of integers. " - f"Got a tuple of types: {[type(i).__name__ for i in _ID]}." + return len(self._dict) + + def __iter__(self: DeepTrackDataDict) -> Iterator[tuple[int, ...]]: + """Iterate over the keys of the dictionary. + + Returns + ------- + Iterator[tuple[int, ...]] + An iterator over the dictionary's keys. + + """ + + return iter(self._dict) + + def items( + self: DeepTrackDataDict, + ) -> ItemsView[tuple[int, ...], DeepTrackDataObject]: + """Return a view of the dictionary’s (key, value) pairs. + + Returns + ------- + ItemsView[tuple[int, ...], DeepTrackDataObject] + A dynamic view of the internal dictionary’s entries. + + """ + + return self._dict.items() + + def keys(self: DeepTrackDataDict) -> KeysView[tuple[int, ...]]: + """Return a view of the dictionary’s keys. + + Returns + ------- + KeysView[tuple[int, ...]] + A dynamic view of the internal dictionary’s keys. + + """ + + return self._dict.keys() + + def values(self: DeepTrackDataDict) -> ValuesView[DeepTrackDataObject]: + """Return a view of the dictionary’s values. + + Returns + ------- + ValuesView[DeepTrackDataObject] + A dynamic view of the internal dictionary’s values. + + """ + return self._dict.values() + + def __repr__(self: DeepTrackDataDict) -> str: + """Return a string representation of the data dictionary. + + Provides a concise summary of the current `DeepTrackDataDict` instance, + including the number of stored entries and the current `keylength`. It + is useful for debugging and logging. + + Returns + ------- + str + A string in the format: + "DeepTrackDataDict( entries, keylength=)". + + """ + + return ( + f"{self.__class__.__name__}" + f"({len(self)} entries, keylength={self.keylength})" ) - return _ID in self.dict + @property + def keylength(self: DeepTrackDataDict) -> int | None: + """Access the internal keylength (read-only). 
+ + This property exposes the internal `_keylength` attribute as a public + read-only interface. + + Returns + ------- + int or None + The key length. + + """ + + return self._keylength + + @property + def dict( + self: DeepTrackDataDict, + ) -> dict[tuple[int, ...], DeepTrackDataObject]: + """Access the internal data dictionary (read-only). + + This property exposes the internal `_dict` attribute as a public + read-only interface. It allows access to all stored data objects + indexed by their `_ID`. + + Returns + ------- + dict[tuple[int, ...], DeepTrackDataObject] + The mapping of `_ID`s to `DeepTrackDataObject` instances. + + """ + + return self._dict class DeepTrackNode: """Node in a DeepTrack2 computation graph, supporting operator overloading. - `DeepTrackNode` represents a node within a DeepTrack2 computation graph. + `DeepTrackNode` represents a node within a DeepTrack2 computation graph. Each node can store data and compute new values based on its dependencies. - The value of a node is computed by calling its `action` method. + The value of a node is computed by calling its `action`. `DeepTrackNode` supports operator overloading, enabling intuitive construction of computation graphs using standard Python operators. @@ -641,78 +828,96 @@ class DeepTrackNode: Parameters ---------- action: Callable or Any, optional - Action to compute this node's value. If not provided, uses a no-op + Action to compute this node's value. If not provided, uses a no-op action (lambda: None). - **kwargs: dict[str, Any] + node_name: str or None, optional + Optional name assigned to the node. Defaults to `None`. + **kwargs: Any Additional arguments for subclasses or extended functionality. Attributes ---------- + node_name: str or None + Optional name assigned to the node. Defaults to `None`. data: DeepTrackDataDict Dictionary-like object for storing data, indexed by tuples of integers. 
children: WeakSet[DeepTrackNode] - Nodes that depend on this node (its children, grandchildren, etc.). + Read-only property exposing the internal weak set `_children` + containing the nodes that depend on this node (its children). This is a weakref.WeakSet, so references are weak and do not prevent garbage collection of nodes that are no longer used. dependencies: WeakSet[DeepTrackNode] - Nodes on which this node depends (its parents, grandparents, etc.). + Read-only property exposing the internal weak set `_dependencies` + containing the nodes on which this node depends (its parents). This is a weakref.WeakSet, for efficient memory management. - _action: Callable + _action: Callable[..., Any] The function or lambda-function to compute the node value. _accepts_ID: bool Whether `action` accepts an input _ID. - _all_children: set[DeepTrackNode] + _all_children: WeakSet[DeepTrackNode] All nodes in the subtree rooted at the node, including the node itself. + This is a weakref.WeakSet, for efficient memory management. + _all_dependencies: WeakSet[DeepTrackNode] + All the dependencies for this node, including the node itself. + This is a weakref.WeakSet, for efficient memory management. _citations: list[str] Citations associated with this node. - + Methods ------- `action: property` Get or set the computation function for the node (stored as `_action`). - `add_child(child: DeepTrackNode) -> DeepTrackNode` + `add_child(child) -> DeepTrackNode` Add a child node that depends on this node. Also add the dependency on this node in the child node. - `add_dependency(parent: DeepTrackNode) -> DeepTrackNode` + `add_dependency(parent) -> DeepTrackNode` Add a dependency, making this node depend on the parent node. Also set this node as a child of the parent node. - `store(data: Any, _ID: tuple[int, ...] = ()) -> DeepTrackNode` + `store(data, _ID) -> DeepTrackNode` Store computed data for the given `_ID`. - `is_valid(_ID: tuple[int, ...] 
= ()) -> bool` + `is_valid(_ID) -> bool` Check whether the data for the given `_ID` is valid. - `valid_index(_ID: tuple[int, ...]) -> bool` + `valid_index(_ID) -> bool` Check whether the given `_ID` is valid for this node. - `invalidate(_ID: tuple[int, ...] = ()) -> DeepTrackNode` + `invalidate(_ID) -> DeepTrackNode` Invalidate the data for the given `_ID` and all child nodes. - `validate(_ID: tuple[int, ...] = ()) -> DeepTrackNode` - Validate the data for the given `_ID`, marking it as up-to-date, but + `validate(_ID) -> DeepTrackNode` + Validate the data for the given `_ID`, marking it as up-to-date, but not its children. `update() -> DeepTrackNode` Reset the data. - `set_value(value: Any, _ID: tuple[int, ...] = ()) -> DeepTrackNode` - Set a value for the given `_ID`. If the new value differs from the - current value, the node is invalidated to ensure dependencies are + `set_value(value, _ID) -> DeepTrackNode` + Set a value for the given `_ID`. If the new value differs from the + current value, the node is invalidated to ensure dependencies are recomputed. - `recurse_children(memory: set[DeepTrackNode] | None = None) -> set[DeepTrackNode]` + `print_children_tree(indent) -> None` + Print a tree of all child nodes (recursively) for debugging. + `recurse_children() -> set[DeepTrackNode]` Return all child nodes in the dependency tree rooted at this node. - `recurse_dependencies(memory: list[DeepTrackNode] | None = None) -> Iterator[DeepTrackNode]` + `print_dependencies_tree(indent) -> None` + Print a tree of all parent nodes (recursively) for debugging. + `recurse_dependencies() -> Iterator[DeepTrackNode]` Yield all nodes that this node depends on, traversing dependencies. `get_citations() -> set[str]` Return a set of citations for this node and its dependencies. - `__call__(_ID: tuple[int, ...] 
= ()) -> Any` - Evaluate the node's computation for the given `_ID`, recomputing if + `__call__(_ID) -> Any` + Evaluate the node's computation for the given `_ID`, recomputing if necessary. - `current_value(_ID: tuple[int, ...] = ()) -> Any` - Return the currently stored value for the given `_ID` without + `current_value(_ID) -> Any` + Return the currently stored value for the given `_ID` without recomputation. + `new(_ID) -> Any` + Reset and recompute the value of this node at the given `_ID`. `__hash__() -> int` Return a unique hash for this node. - `__getitem__(idx: Any) -> DeepTrackNode` + `__getitem__(idx) -> DeepTrackNode` Creates a new node that indexes into this node's computed data. + `__repr__(self) -> str:` + Return a string representation of the node. Supported Operators ------------------- - DeepTrackNode supports the following Python operators: + `DeepTrackNode` supports the following Python operators: Arithmetic: + Addition (__add__, __radd__) @@ -722,26 +927,61 @@ class DeepTrackNode: // Floor division (__floordiv__, __rfloordiv__) Comparison: - < Less than (__lt__, __rlt__) - <= Less than or equal (__le__, __rle__) - > Greater than (__gt__, __rgt__) - >= Greater than or equal (__ge__, __rge__) + < Less than (__lt__, __gt__) + > Greater than (__gt__, __lt__) + <= Less than or equal (__le__, __ge__) + >= Greater than or equal (__ge__, __le__) - Each operation returns a new DeepTrackNode representing the - result of the corresponding operation in the computation graph. + Each operation returns a new `DeepTrackNode` representing the result of the + corresponding operation in the computation graph. 
- Example - ------- + Examples + -------- >>> from deeptrack.backend.core import DeepTrackNode - Create two `DeepTrackNode` objects, one as a parent and one as a child: - - >>> parent = DeepTrackNode(action=lambda: 10) - >>> child = DeepTrackNode(action=lambda _ID=None: parent(_ID) * 2) + Create three `DeepTrackNode` objects, as parent, child, and grandchild: + >>> parent = DeepTrackNode( + ... node_name="parent", + ... action=lambda: 10, + ... ) + >>> child = DeepTrackNode( + ... node_name="child", + ... action=lambda _ID=None: parent(_ID) * 2, + ... ) + >>> grandchild = DeepTrackNode( + ... node_name="grandchild", + ... action=lambda _ID=None: child(_ID) * 3, + ... ) >>> parent.add_child(child) + >>> child.add_child(grandchild) + + Check all children of `parent` (includes `parent` itself): + >>> for node in parent.recurse_children(): + ... print(node) + DeepTrackNode(name='parent', len=0, action=) + DeepTrackNode(name='child', len=0, action=) + DeepTrackNode(name='grandchild', len=0, action=) + + Print the children tree: + >>> parent.print_children_tree() + - DeepTrackNode 'parent' at 0x334202650 + - DeepTrackNode 'child' at 0x334201cf0 + - DeepTrackNode 'grandchild' at 0x334201ea0 + + Check all dependencies of `grandchild` (includes `grandchild` itself): + >>> for node in grandchild.recurse_dependencies(): + ... 
print(node) + DeepTrackNode(name='grandchild', len=0, action=) + DeepTrackNode(name='child', len=0, action=) + DeepTrackNode(name='parent', len=0, action=) + + Print the dependency tree: + >>> grandchild.print_dependencies_tree() + - DeepTrackNode 'grandchild' at 0x334201ea0 + - DeepTrackNode 'child' at 0x334201cf0 + - DeepTrackNode 'parent' at 0x334202650 Store and retrieve data for specific _IDs: - >>> parent.store(15, _ID=(0,)) >>> parent.store(20, _ID=(1,)) >>> parent.current_value((0,)) @@ -749,58 +989,64 @@ class DeepTrackNode: >>> parent.current_value((1,)) 20 - Compute and retrieve the value for the child node: - + Compute and retrieve the value for the child and grandchild node: >>> child(_ID=(0,)) 30 >>> child(_ID=(1,)) 40 + >>> grandchild(_ID=(0,)) + 90 + >>> grandchild(_ID=(1,)) + 120 Validation and invalidation: - >>> parent.is_valid((0,)) True >>> child.is_valid((0,)) True + >>> grandchild.is_valid((0,)) + True - >>> parent.invalidate((0,)) + >>> parent.invalidate((0,)) # Also invalidate child and grandchild >>> parent.is_valid((0,)) False >>> child.is_valid((0,)) False + >>> grandchild.is_valid((0,)) + False - >>> parent.validate((0,)) + >>> child.validate((0,)) >>> parent.is_valid((0,)) - True + False >>> child.is_valid((0,)) + True + >>> grandchild.is_valid((0,)) False Setting a value and automatic invalidation: - >>> parent.current_value((0,)) 15 - >>> child((1,)) # Computes and stores the value in child - >>> child.current_value((0,)) - 30 + >>> grandchild((0,)) # Computes and stores the value in grandchild + >>> grandchild.current_value((0,)) + 90 >>> parent.set_value(42, _ID=(0,)) >>> parent.current_value((0,)) 42 - >>> child((0,)) # Recomputes and stores the value in child - >>> child.current_value((0,)) - 84 + >>> grandchild((0,)) # Recomputes and stores the value in grandchild + >>> grandchild.current_value((0,)) + 252 Resetting all data in the dependency tree (recomputation required): + >>> grandchild.update() + >>> grandchild() + 60 - >>> 
parent.update() - - Dependency graph traversal (children and dependencies): - - >>> all_children = parent.recurse_children() - >>> all_dependencies = list(child.recurse_dependencies()) + This is equivalent to: + >>> grandchild.new() + 60 Operator overloading—arithmetic and comparison: - >>> node_a = DeepTrackNode(lambda: 5) >>> node_b = DeepTrackNode(lambda: 3) @@ -833,37 +1079,56 @@ class DeepTrackNode: True Indexing into computed data: - >>> vector_node = DeepTrackNode(lambda: [10, 20, 30]) >>> first_element = vector_node[0] >>> first_element() 10 - Citations for a node and its dependencies: + Accessing a value before computing it raises an error: + >>> new_node = DeepTrackNode(lambda: 123) + >>> new_node.is_valid((42,)) + False + >>> new_node.current_value((42,)) + KeyError: 'Attempting to index an empty dict.' + Working with nested _ID slicing: + >>> parent = DeepTrackNode(lambda: 5) + >>> child = DeepTrackNode(lambda _ID=None: parent(_ID[:1]) + _ID[1]) + >>> parent.add_child(child) + >>> child((0, 3)) # Equivalent to parent((0,)) + 3 + 8 + + Citations for a node and its dependencies: >>> parent.get_citations() # Set of citation strings - {...} + {...} """ - # Attributes. + node_name: str | None data: DeepTrackDataDict - children: WeakSet[DeepTrackNode] - dependencies: WeakSet[DeepTrackNode] + _children: WeakSet[DeepTrackNode] + _dependencies: WeakSet[DeepTrackNode] + _all_children: WeakSet[DeepTrackNode] + _all_dependencies: WeakSet[DeepTrackNode] + _action: Callable[..., Any] _accepts_ID: bool - _all_children: set[DeepTrackNode] # Citations associated with DeepTrack2. _citations: list[str] = [CITATION_MIDTVEDT2021QUANTITATIVE] @property def action(self: DeepTrackNode) -> Callable[..., Any]: - """Callable: The function that computes this node's value. + """Get the function used to compute this node's value. When accessed, it returns the current action. 
This is often a function - or lambda-function that takes `_ID` as an optional parameter if - `_accepts_ID` is `True`. + or lambda-function that takes `_ID` as an optional parameter if + `_accepts_ID` is True. + + Returns + ------- + Callable[..., Any] + The function used to compute this node's value. """ @@ -882,7 +1147,7 @@ def action( A function or lambda-function used for computing the node's value. If the function's signature includes `_ID`, this node will pass `_ID` when calling `action`. - + """ self._action = _action @@ -890,7 +1155,8 @@ def action( def __init__( self: DeepTrackNode, - action: Callable[..., Any] | None = None, + action: Callable[..., Any] | Any = None, + node_name: str | None = None, **kwargs: Any, ): """Initialize a new DeepTrackNode. @@ -898,71 +1164,99 @@ def __init__( Parameters ---------- action: Callable or Any, optional - Action to compute this node's value. If not provided, uses a no-op + Action to compute this node's value. If not provided, uses a no-op action (lambda: None). - **kwargs: dict[str, Any] + node_name: str or None, optional + Optional name for the node. Defaults to `None`. + **kwargs: Any Additional arguments for subclasses or extended functionality. - + """ + # Call super init in case of multiple inheritance. + super().__init__(**kwargs) + + # Initialize attributes. + self.node_name = node_name self.data = DeepTrackDataDict() - self.children = WeakSet() - self.dependencies = WeakSet() - self._action = lambda: None # Default no-op action + self._children = WeakSet() + self._dependencies = WeakSet() # If action is provided, set it. # If it's callable, use it directly; # otherwise, wrap it in a lambda. - if action is not None: - if callable(action): - self.action = action - else: - self.action = lambda: action + if callable(action): + self._action = action + else: + self._action = lambda: action # Check if action accepts `_ID`. 
self._accepts_ID = "_ID" in get_kwarg_names(self.action) - # Call super init in case of multiple inheritance. - super().__init__(**kwargs) - # Keep track of all children, including this node. - self._all_children = set() + self._all_children = WeakSet() #TODO ***BM*** Ok WeakSet from set? self._all_children.add(self) + # Keep track of all dependencies, including this node. + self._all_dependencies = WeakSet() #TODO ***BM*** Ok this addition? + self._all_dependencies.add(self) + def add_child( self: DeepTrackNode, child: DeepTrackNode, ) -> DeepTrackNode: """Add a child node to the current node. - Adding a child also updates `_all_children` for this node and all - its dependencies. It also ensures that dependency and child - relationships remain consistent. + Adds `child` to `self._children`, and `self` to `child._dependencies`. + Also updates `_all_children` for `self` and its dependencies, as well + as `_all_dependencies` for `self` and its children. Parameters ---------- child: DeepTrackNode The child node that depends on this node. - + Returns ------- self: DeepTrackNode Return the current node for chaining. + Raises + ------ + ValueError + If adding this child would introduce a cycle in the dependency + graph. + """ - self.children.add(child) - if self not in child.dependencies: - child.add_dependency(self) # Ensure bidirectional relationship. + # Check for cycle: if `self` is already in `child`'s dependency tree + if self in child.recurse_children(): + raise ValueError( + f"Adding {child.node_name} as child to {self.node_name} " + f"would create a cycle." + ) + + self._children.add(child) + child._dependencies.add(self) # Ensure bidirectional relationship - # Get all children of `child` and add `child` itself. - children = child._all_children.copy() - children.add(child) + # Get all children of `child`, which includes `child` itself. + child_all_children = child._all_children.copy() # Merge all these children into this node's subtree. 
- self._all_children = self._all_children.union(children) + self._all_children = self._all_children.union(child_all_children) for parent in self.recurse_dependencies(): - parent._all_children = parent._all_children.union(children) + parent._all_children = \ + parent._all_children.union(child_all_children) + + # Get all dependencies of `self`, which includes `self` itself. + self_all_dependencies = self._all_dependencies.copy() + + # Merge all these dependencies into the child's subtree. + child._all_dependencies = \ + child._all_dependencies.union(self_all_dependencies) + for grandchild in child.recurse_children(): + grandchild._all_dependencies = \ + grandchild._all_dependencies.union(self_all_dependencies) return self @@ -970,24 +1264,26 @@ def add_dependency( self: DeepTrackNode, parent: DeepTrackNode, ) -> DeepTrackNode: - """Adds a dependency, making this node depend on a parent node. + """Add a dependency, making this node depend on a parent node. + + Adds `parent` to `self._dependencies` and `self` to `parent._children`. + Also updates `_all_children` for `parent` and its dependencies, as well + as `_all_dependencies` for `self` and its children. Parameters ---------- parent: DeepTrackNode - The parent node that this node depends on. If `parent` changes, - this node's data may become invalid. + The parent node that this node depends on. If `parent` changes, + this node's data becomes invalid. Returns ------- self: DeepTrackNode Return the current node for chaining. - - """ - self.dependencies.add(parent) + """ - parent.add_child(self) # Ensure the child relationship is also set. + parent.add_child(self) return self @@ -1003,19 +1299,20 @@ def store( data: Any The data to be stored. _ID: tuple[int, ...], optional - The index for this data. If the _ID does not exist, it creates it. - Default is the empty tuple (), indicating a root-level entry. + The index for this data. If `_ID` does not exist, it creates it. + Defaults to (), indicating a root-level entry. 
Returns ------- self: DeepTrackNode Return the current node for chaining. - + """ - # Create the index if necessary, then store data in it. + # Create the index if necessary self.data.create_index(_ID) + # Then store data in it self.data[_ID].store(data) return self @@ -1024,7 +1321,7 @@ def is_valid( self: DeepTrackNode, _ID: tuple[int, ...] = (), ) -> bool: - """Check if data for the given _ID is valid. + """Check whether data for the given _ID is valid. Parameters ---------- @@ -1035,7 +1332,7 @@ def is_valid( ------- bool `True` if data at `_ID` is valid, otherwise `False`. - + """ try: @@ -1058,7 +1355,7 @@ def valid_index( ------- bool `True` if `_ID` is valid, otherwise `False`. - + """ return self.data.valid_index(_ID) @@ -1069,26 +1366,24 @@ def invalidate( ) -> DeepTrackNode: """Mark this node's data and all its children's data as invalid. + NOTE: At the moment, the code to invalidate specific `_ID`s is not + implemented, so the `_ID` parameter is not effectively used. + TODO: Implement the invalidation of specific `_ID`s. + Parameters ---------- _ID: tuple[int, ...], optional - The _ID to invalidate. Default is empty tuple, indicating + The _ID to invalidate. Default is empty tuple, indicating potentially the full dataset. Returns ------- self: DeepTrackNode Return the current node for chaining. - - Note - ---- - At the moment, the code to invalidate specific _IDs is not implemented, - so the _ID parameter is not effectively used. """ # Invalidate data for all children of this node. - for child in self.recurse_children(): child.data.invalidate() @@ -1103,7 +1398,7 @@ def validate( Parameters ---------- _ID: tuple[int, ...], optional - The _ID to validate. Default is empty tuple. + The _ID to validate. Defaults to empty tuple. Returns ------- @@ -1118,25 +1413,21 @@ def validate( def update(self: DeepTrackNode) -> DeepTrackNode: """Reset data in all children. 
- This method resets `data` for all children of each dependency, - effectively clearing cached values to force a recomputation on the next + This method resets `data` for all children of each dependency, + effectively clearing cached values to force a recomputation on the next evaluation. - + Returns ------- self: DeepTrackNode Return the current node for chaining. - - """ - # Pre-instantiate memory for optimization, - # used to avoid repeated processing of the same nodes. - child_memory = [] + """ # For each dependency, reset data in all of its children. for dependency in self.recurse_dependencies(): - for dep_child in dependency.recurse_children(memory=child_memory): - dep_child.data = DeepTrackDataDict() + for dependency_child in dependency.recurse_children(): + dependency_child.data = DeepTrackDataDict() return self @@ -1147,7 +1438,7 @@ def set_value( ) -> DeepTrackNode: """Set a value for this node's data at _ID. - If the value is different from the currently stored one (or if it is + If the value is different from the currently stored one (or if it is invalid), it will invalidate the old data before storing the new one. Parameters @@ -1155,13 +1446,13 @@ def set_value( value: Any The value to store. _ID: tuple[int, ...], optional - The _ID at which to store the value. + The `_ID` at which to store the value. Returns ------- self: DeepTrackNode Return the current node for chaining. - + """ # Check if current value is equivalent. If not, invalidate and store @@ -1175,22 +1466,29 @@ def set_value( return self - - def recurse_children( - self: DeepTrackNode, - memory: set[DeepTrackNode] | None = None, - ) -> set[DeepTrackNode]: - """Return all children of this node. + def print_children_tree(self: DeepTrackNode, indent: int = 0) -> None: + """Print a tree of all child nodes (recursively) for debugging. Parameters ---------- - memory: set, optional - Set of nodes that have already been visited (not used directly - here). 
+ indent: int, optional + The indentation level (used internally during recursion). + + """ + + prefix = " " * (indent * 4) + name = f"{self.node_name!r}" if self.node_name else "" + print(f"{prefix}- {self.__class__.__name__} {name} at {hex(id(self))}") + + for child in self._children: + child.print_children_tree(indent=indent + 1) + + def recurse_children(self: DeepTrackNode) -> WeakSet[DeepTrackNode]: + """Return all children of this node. Returns ------- - set + WeakSet[DeepTrackNode] All nodes in the subtree rooted at this node, including itself. """ @@ -1207,7 +1505,7 @@ def old_recurse_children( Parameters ---------- memory: list, optional - A list to remember visited nodes, ensuring that each node is + A list to remember visited nodes, ensuring that each node is yielded only once. Yields @@ -1236,14 +1534,44 @@ def old_recurse_children( yield self # Recursively traverse children. - for child in self.children: + for child in self._children: yield from child.recurse_children(memory=memory) - def recurse_dependencies( + def print_dependencies_tree(self: DeepTrackNode, indent: int = 0) -> None: + """Print a tree of all parent nodes (recursively) for debugging. + + Parameters + ---------- + indent: int, optional + The indentation level (used internally during recursion). + + """ + + prefix = " " * (indent * 4) + name = f"{self.node_name!r}" if self.node_name else "" + print(f"{prefix}- {self.__class__.__name__} {name} at {hex(id(self))}") + + for parent in self._dependencies: + parent.print_dependencies_tree(indent=indent + 1) + + def recurse_dependencies(self: DeepTrackNode) -> WeakSet[DeepTrackNode]: + """Return all dependencies of this node. + + Returns + ------- + WeakSet[DeepTrackNode] + All the dependencies of this node, including itself. + + """ + + # Simply return `_all_dependencies` as it's maintained incrementally. 
+ return self._all_dependencies + + def old_recurse_dependencies( self: DeepTrackNode, memory: list[DeepTrackNode] | None = None, ) -> Iterator[DeepTrackNode]: - """Yield all dependencies of this node, ensuring each is visited once. + """Legacy recursive method for traversing all dependencies. Parameters ---------- @@ -1254,7 +1582,11 @@ def recurse_dependencies( ------ DeepTrackNode Yields this node and all nodes it depends on. - + + Notes + ----- + This method is kept for backward compatibility or debugging purposes. + """ # On first call, instantiate memory. @@ -1272,20 +1604,20 @@ def recurse_dependencies( yield self # Recursively yield dependencies. - for dependency in self.dependencies: + for dependency in self._dependencies: yield from dependency.recurse_dependencies(memory=memory) def get_citations(self: DeepTrackNode) -> set[str]: """Get citations from this node and all its dependencies. - It gathers citations from this node and all nodes that it depends on. + Gathers citations from this node and all nodes that it depends on. Citations are stored as the class attribute `_citations`. Returns ------- set[str] Set of all citations relevant to this node and its dependency tree. - + """ # Initialize citations as a set of elements from self.citations. @@ -1309,13 +1641,15 @@ def __call__( ) -> Any: """Evaluate this node at _ID. - If the data at `_ID` is valid, it returns the stored value. Otherwise, - it calls `action` to compute a new value, stores it, and returns it. + If valid data is already stored at `_ID`, it is returned. Otherwise, + the node's `action` function is called to compute the value, which is + then stored and returned. The `_ID` is passed to `action` only if it + is declared to accept it. Parameters ---------- _ID: tuple[int, ...], optional - The _ID at which to evaluate the node's action. + The `_ID` at which to evaluate the node's action. Defaults to `()`. 
Returns ------- @@ -1324,6 +1658,7 @@ def __call__( """ + # First try to return the already stored value, if it's valid. if self.is_valid(_ID): try: return self.current_value(_ID) @@ -1339,6 +1674,7 @@ def __call__( # Store the newly computed value. self.store(new_value, _ID=_ID) + # Return the newly stored value. return self.current_value(_ID) def current_value( @@ -1350,22 +1686,46 @@ def current_value( Parameters ---------- _ID: tuple[int, ...], optional - The _ID at which to retrieve the current value. + The `_ID` at which to retrieve the current value. Defaults to `()`. Returns ------- Any The currently stored value for `_ID`. - + """ return self.data[_ID].current_value() + def new( + self: DeepTrackNode, + _ID: tuple[int, ...] = (), + ) -> Any: + """Reset and recompute the value of this node at the given _ID. + + Clears the stored data in this node and its dependencies, then + immediately computes and returns the new value for the given `_ID`. + + Parameters + ---------- + _ID: tuple[int, ...], optional + The identifier for which the value should be recomputed. Defaults to + an empty tuple. + + Returns + ------- + Any + The newly computed value at the given `_ID`. + + """ + + return self.update()(_ID) + def __hash__(self: DeepTrackNode) -> int: """Return a unique hash for this node. Uses the node's `id` to ensure uniqueness. - + """ return id(self) @@ -1376,6 +1736,10 @@ def __getitem__( ) -> DeepTrackNode: """Allow indexing into the node's computed data. + NOTE: This effectively creates a node that corresponds to + `self(...)[idx]`, allowing to select parts of the computed data + dynamically. + Parameters ---------- idx: Any @@ -1384,24 +1748,56 @@ def __getitem__( Returns ------- DeepTrackNode - A new node that, when evaluated, applies `idx` to the result of + A new node that, when evaluated, applies `idx` to the result of `self`. 
- Notes - ----- - This effectively creates a node that corresponds to `self(...)[idx]`, - allowing you to select parts of the computed data dynamically. - """ # Create a new node whose action indexes into this node's result. node = DeepTrackNode(lambda _ID=None: self(_ID=_ID)[idx]) self.add_child(node) - # node.add_dependency(self) # Already executed by add_child. return node + def __repr__(self: DeepTrackNode) -> str: + """Return a string representation of the node. + + This method returns a concise textual description of the node for + debugging and introspection. The string includes: + + - The node's class name (`DeepTrackNode`) + - Its `name`, if provided + - The number of stored data entries (`len`) + - The name of the action function or type (`action`) + - The list of stored `_ID`s (excluding the root `()`), if any exist + + Returns + ------- + str + A string in the format: "DeepTrackNode(name='', len=, + action=, IDs=[...])" Fields `name=...` and `IDs=[...]` + are included only if applicable. + + """ + + action_name = getattr( + self._action, + "__name__", + type(self._action).__name__, + ) + + ID_list = [_ID for _ID in self.data.dict if _ID != tuple()] + + parts = [ + f"name='{self.node_name}'" if self.node_name else None, + f"len={len(self.data)}", + f"action={action_name}", + f"IDs={ID_list}" if ID_list else None, + ] + + return f"{self.__class__.__name__}({', '.join(p for p in parts if p)})" + # Node-node operators. # These methods define arithmetic and comparison operations for # DeepTrackNode objects. Each operation creates a new DeepTrackNode that @@ -1416,7 +1812,7 @@ def __add__( """Add node to another node or value. Creates a new `DeepTrackNode` representing the addition of the values - produced by this node (`self`) and another node or value (`other`). + produced by the `self` node and the `other` node or value. 
Parameters ---------- @@ -1426,20 +1822,20 @@ def __add__( Returns ------- DeepTrackNode - A new node that represents the addition operation (`self + other`). - + A new node that represents the addition operation `self + other`. + """ return _create_node_with_operator(operator.__add__, self, other) def __radd__( self: DeepTrackNode, - other: DeepTrackNode | Any, + other: Any, ) -> DeepTrackNode: """Add other value to node (right-hand). - Creates a new `DeepTrackNode` representing the addition of another - node or value (`other`) to the value produced by this node (`self`). + Creates a new `DeepTrackNode` representing the addition of the `other` + value and the `self` node. Parameters ---------- @@ -1449,8 +1845,8 @@ def __radd__( Returns ------- DeepTrackNode - A new node that represents the addition operation (`other + self`). - + A new node that represents the addition operation `other + self`. + """ return _create_node_with_operator(operator.__add__, other, self) @@ -1459,11 +1855,10 @@ def __sub__( self: DeepTrackNode, other: DeepTrackNode | Any, ) -> DeepTrackNode: - """Subtract another node or value from node. + """Subtract a node from another node or value. - Creates a new `DeepTrackNode` representing the subtraction of the - values produced by another node or value (`other`) from this node - (`self`). + Creates a new `DeepTrackNode` representing the subtraction of the + values produced by the `self`node and the `other` node or value. Parameters ---------- @@ -1473,33 +1868,33 @@ def __sub__( Returns ------- DeepTrackNode - A new node that represents the subtraction operation - (`self - other`). - + A new node that represents the subtraction operation + `self - other`. + """ return _create_node_with_operator(operator.__sub__, self, other) def __rsub__( self: DeepTrackNode, - other: DeepTrackNode | Any, + other: Any, ) -> DeepTrackNode: """Subtract node from other value (right-hand). 
Creates a new `DeepTrackNode` representing the subtraction of the value - produced by this node (`self`) from another node or value (`other`). + produced by the `other` value from the `self` node. Parameters ---------- - other: DeepTrackNode or Any + other: Any The value or node to subtract from. Returns ------- DeepTrackNode - A new node that represents the subtraction operation - `other - self`). - + A new node that represents the subtraction operation + `other - self`. + """ return _create_node_with_operator(operator.__sub__, other, self) @@ -1510,9 +1905,8 @@ def __mul__( ) -> DeepTrackNode: """Multiply node by another node or value. - Creates a new `DeepTrackNode` representing the multiplication of the - values produced by this node (`self`) and another node or value - (`other`). + Creates a new `DeepTrackNode` representing the multiplication of the + values produced by the `self` node and the `other` node or value. Parameters ---------- @@ -1522,34 +1916,35 @@ def __mul__( Returns ------- DeepTrackNode - A new node that represents the multiplication operation - (`self * other`). - + A new node that represents the multiplication operation + `self * other`. + """ return _create_node_with_operator(operator.__mul__, self, other) def __rmul__( self: DeepTrackNode, - other: DeepTrackNode | Any, + other: Any, ) -> DeepTrackNode: """Multiply other value by node (right-hand). - Creates a new `DeepTrackNode` representing the multiplication of - another node or value (`other`) by the value produced by this node - (`self`). + Creates a new `DeepTrackNode` representing the multiplication of the + `other` value by the self node. Parameters ---------- - other: DeepTrackNode or Any + other: Any The value or node to multiply. Returns ------- DeepTrackNode - A new node that represents the multiplication operation - (`other * self`). + A new node that represents the multiplication operation + `other * self`. 
+ """ + return _create_node_with_operator(operator.__mul__, other, self) def __truediv__( @@ -1559,7 +1954,7 @@ def __truediv__( """Divide node by another node or value. Creates a new `DeepTrackNode` representing the division of the value - produced by this node (`self`) by another node or value (`other`). + produced by the `self` node by the `other` node or value. Parameters ---------- @@ -1570,30 +1965,30 @@ def __truediv__( ------- DeepTrackNode A new node that represents the division operation (`self / other`). - + """ return _create_node_with_operator(operator.__truediv__, self, other) def __rtruediv__( self: DeepTrackNode, - other: DeepTrackNode | Any, + other: Any, ) -> DeepTrackNode: """Divide other value by node (right-hand). - Creates a new `DeepTrackNode` representing the division of another - node or value (`other`) by the value produced by this node (`self`). + Creates a new `DeepTrackNode` representing the division of the `other` + value by the `self` node. Parameters ---------- - other: DeepTrackNode or Any + other: Any The value or node to divide. Returns ------- DeepTrackNode - A new node that represents the division operation (`other / self`). - + A new node that represents the division operation `other / self`. + """ return _create_node_with_operator(operator.__truediv__, other, self) @@ -1605,8 +2000,7 @@ def __floordiv__( """Perform floor division of node by another node or value. Creates a new `DeepTrackNode` representing the floor division of the - value produced by this node (`self`) by another node or value - (`other`). + value produced by the `self` node by the `other` node or value. Parameters ---------- @@ -1616,34 +2010,33 @@ def __floordiv__( Returns ------- DeepTrackNode - A new node that represents the floor division operation - (`self // other`). - + A new node that represents the floor division operation + `self // other`. 
+ """ return _create_node_with_operator(operator.__floordiv__, self, other) def __rfloordiv__( self: DeepTrackNode, - other: DeepTrackNode | Any, + other: Any, ) -> DeepTrackNode: """Perform floor division of other value by node (right-hand). - Creates a new `DeepTrackNode` representing the floor division of - another node or value (`other`) by the value produced by this node - (`self`). + Creates a new `DeepTrackNode` representing the floor division of the + other value by the `self` node. Parameters ---------- - other: DeepTrackNode or Any + other: Any The value or node to divide. Returns ------- DeepTrackNode - A new node that represents the floor division operation - (`other // self`). - + A new node that represents the floor division operation + `other // self`. + """ return _create_node_with_operator(operator.__floordiv__, other, self) @@ -1652,10 +2045,10 @@ def __lt__( self: DeepTrackNode, other: DeepTrackNode | Any, ) -> DeepTrackNode: - """Check if node is less than another node or value. + """Check whether node is less than other node or value. - Creates a new `DeepTrackNode` representing the comparison of this node - (`self`) being less than another node or value (`other`). + Creates a new `DeepTrackNode` representing whether the `self` node is + less than the `other` node or value. Parameters ---------- @@ -1665,45 +2058,20 @@ def __lt__( Returns ------- DeepTrackNode - A new node that represents the comparison operation - (`self < other`). - - """ + A new node that represents the comparison `self < other`. - return _create_node_with_operator(operator.__lt__, self, other) - - def __rlt__( - self: DeepTrackNode, - other: DeepTrackNode | Any, - ) -> DeepTrackNode: - """Check if other value is less than node (right-hand). - - Creates a new `DeepTrackNode` representing the comparison of another - node or value (`other`) being less than this node (`self`). - - Parameters - ---------- - other: DeepTrackNode or Any - The value or node to compare. 
- - Returns - ------- - DeepTrackNode - A new node that represents the comparison operation - (`other < self`). - """ - return _create_node_with_operator(operator.__lt__, other, self) + return _create_node_with_operator(operator.__lt__, self, other) def __gt__( self: DeepTrackNode, other: DeepTrackNode | Any, ) -> DeepTrackNode: - """Check if node is greater than another node or value. + """Check whether node is greater than other node or value. - Creates a new `DeepTrackNode` representing the comparison of this node - (`self`) being greater than another node or value (`other`). + Creates a new `DeepTrackNode` representing whether the `self` node is + greater than the `other` node or value. Parameters ---------- @@ -1713,45 +2081,20 @@ def __gt__( Returns ------- DeepTrackNode - A new node that represents the comparison operation - (`self > other`). - - """ - - return _create_node_with_operator(operator.__gt__, self, other) - - def __rgt__( - self: DeepTrackNode, - other: DeepTrackNode | Any, - ) -> DeepTrackNode: - """Check if other value is greater than node (right-hand). - - Creates a new `DeepTrackNode` representing the comparison of another - node or value (`other`) being greater than this node (`self`). + A new node that represents the comparison `self > other`. - Parameters - ---------- - other: DeepTrackNode or Any - The value or node to compare. - - Returns - ------- - DeepTrackNode - A new node that represents the comparison operation - (`other > self`). - """ - return _create_node_with_operator(operator.__gt__, other, self) + return _create_node_with_operator(operator.__gt__, self, other) def __le__( self: DeepTrackNode, other: DeepTrackNode | Any, ) -> DeepTrackNode: - """Check if node is less than or equal to another node or value. + """Check whether node is less than or equal to other node or value. - Creates a new `DeepTrackNode` representing the comparison of this node - (`self`) being less than or equal to another node or value (`other`). 
+ Creates a new `DeepTrackNode` representing whether the `self` node is
+ less than or equal to the `other` node or value.

 Parameters
 ----------
@@ -1761,86 +2104,66 @@
 Returns
 -------
 DeepTrackNode
- A new node that represents the comparison operation
- (`self <= other`).
- 
+ A new node that represents the comparison `self <= other`.
+ 
 """

 return _create_node_with_operator(operator.__le__, self, other)

- def __rle__(
+ def __ge__(
 self: DeepTrackNode,
 other: DeepTrackNode | Any,
 ) -> DeepTrackNode:
- """Check if other value is less than or equal to node (right-hand).
+ """Check whether node is greater than or equal to other node or value.

- Creates a new `DeepTrackNode` representing the comparison of another
- node or value (`other`) being less than or equal to this node (`self`).
+ Creates a new `DeepTrackNode` representing whether the `self` node is
+ greater than or equal to the `other` node or value.

 Parameters
 ----------
 other: DeepTrackNode or Any
- The value or node to compare.
+ The node or value to compare with.

 Returns
 -------
 DeepTrackNode
- A new node that represents the comparison operation
- (`other <= self`).
- 
- """
+ A new node that represents the comparison `self >= other`.

- return _create_node_with_operator(operator.__le__, other, self)
+ """

- def __ge__(
- self: DeepTrackNode,
- other: DeepTrackNode | Any,
- ) -> DeepTrackNode:
- """Check if node is greater than or equal to another node or value.
+ return _create_node_with_operator(operator.__ge__, self, other)

- Creates a new `DeepTrackNode` representing the comparison of this node
- (`self`) being greater than or equal to another node or value
- (`other`).
+ @property
+ def dependencies(self: DeepTrackNode) -> WeakSet[DeepTrackNode]:
+ """Access the dependencies of the node (read-only).

- Parameters
- ----------
- other: DeepTrackNode or Any
- The node or value to compare with.
+ This property exposes the internal `_dependencies` attribute as a
+ public read-only interface.
 Returns
 -------
- DeepTrackNode
- A new node that represents the comparison operation
- (`self >= other`).
+ WeakSet[DeepTrackNode]
+ A weak set with the dependencies of this node.

 """

- return _create_node_with_operator(operator.__ge__, self, other)
-
- def __rge__(
- self: DeepTrackNode,
- other: DeepTrackNode | Any,
- ) -> DeepTrackNode:
- """Check if other value is greater than or equal to node (right-hand).
+ return self._dependencies

- Creates a new `DeepTrackNode` representing the comparison of another
- node or value (`other`) being greater than or equal to this node
- (`self`).
+ @property
+ def children(self: DeepTrackNode) -> WeakSet[DeepTrackNode]:
+ """Access the children of the node (read-only).

- Parameters
- ----------
- other: DeepTrackNode or Any
- The value or node to compare.
+ This property exposes the internal `_children` attribute as a public
+ read-only interface.

 Returns
 -------
- DeepTrackNode
- A new node that represents the comparison operation
- (`other >= self`).
- 
+ WeakSet[DeepTrackNode]
+ A weak set with the children of this node.
+ 
 """
- return _create_node_with_operator(operator.__ge__, other, self)
+ return self._children


 def _equivalent(
@@ -1849,16 +2172,17 @@
 ) -> bool:
 """Check if two objects are equivalent.

- This internal helper function provides a basic implementation to determine
+ This internal helper function provides a basic implementation to determine 
 equivalence between two objects:

- - If `a` and `b` are the same object (identity check), they are considered
+ - If `a` and `b` are the same object (identity check), they are considered 
 equivalent.
 - If both `a` and `b` are empty lists, they are considered equivalent.
+ Additional cases can be implemented as needed to refine this behavior.

- For immutable built-in types like empty tuples, integers, and `None`, Python
- may reuse the same object in memory. Thus, `a is b` may return True even if
- the objects are created separately.
+ NOTE: For immutable built-in types like empty tuples, integers, and `None`, + Python may reuse the same object in memory. Thus, `a is b` may return + `True` even if the objects are created separately. Parameters ---------- @@ -1913,16 +2237,16 @@ def _create_node_with_operator( ) -> DeepTrackNode: """Create a new computation node using the given operator and operands. - This internal helper function constructs a `DeepTrackNode` obtained from - the application of the specified operator to two operands. If the operands + This internal helper function constructs a `DeepTrackNode` obtained from + the application of the specified operator to two operands. If the operands are not already `DeepTrackNode` instances, they are converted to nodes. - This function also establishes bidirectional relationships between the new + This function also establishes bidirectional relationships between the new node and its operands: - + - The new node is added as a child of the operands `a` and `b`. - The operands `a` and `b` are added as dependencies of the new node. - - The operator `op` is applied lazily, meaning it will be evaluated when + - The operator `op` is applied lazily, meaning it will be evaluated when the new node is called, for computational efficiency. Parameters @@ -1937,7 +2261,7 @@ def _create_node_with_operator( Returns ------- DeepTrackNode - A new `DeepTrackNode` containing the result of applying the operator + A new `DeepTrackNode` containing the result of applying the operator `op` to the values of nodes `a` and `b`. """ @@ -1958,9 +2282,4 @@ def _create_node_with_operator( a.add_child(new_node) b.add_child(new_node) - # Establish dependency relationships between the nodes. - # (Not needed because already done implicitly above.) 
- # new_node.add_dependency(a) - # new_node.add_dependency(b) - return new_node diff --git a/deeptrack/features.py b/deeptrack/features.py index a4bf2c709..9d75ee5d3 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -364,7 +364,7 @@ class Feature(DeepTrackNode): `store_properties(toggle: bool, recursive: bool) -> Feature` It controls whether the properties are stored in the output `Image` object. - `torch(device: torch.device or None, recursive: bool) -> 'Feature'` + `torch(device: torch.device or None, recursive: bool) -> Feature` It sets the backend to torch. `numpy(recursice: bool) -> Feature` It set the backend to numpy. diff --git a/deeptrack/tests/backend/test_core.py b/deeptrack/tests/backend/test_core.py index b143f6063..4ba8bae64 100644 --- a/deeptrack/tests/backend/test_core.py +++ b/deeptrack/tests/backend/test_core.py @@ -25,153 +25,322 @@ def test___all__(self): DeepTrackNode, ) + def test_DeepTrackDataObject(self): dataobj = core.DeepTrackDataObject() - # Test storing and validating data. + # Test default initialization + self.assertEqual(dataobj.current_value(), None) + self.assertEqual(dataobj.is_valid(), False) + + # Test storing and validating data dataobj.store(1) self.assertEqual(dataobj.current_value(), 1) self.assertEqual(dataobj.is_valid(), True) - # Test invalidating data. + # Test invalidating data dataobj.invalidate() self.assertEqual(dataobj.current_value(), 1) self.assertEqual(dataobj.is_valid(), False) - # Test validating data. 
+ # Test validating data dataobj.validate() self.assertEqual(dataobj.current_value(), 1) self.assertEqual(dataobj.is_valid(), True) + # Test updating data + dataobj.store(2) + self.assertEqual(dataobj.current_value(), 2) + self.assertEqual(dataobj.is_valid(), True) + + def test_DeepTrackDataDict(self): - dataset = core.DeepTrackDataDict() + datadict = core.DeepTrackDataDict() + + # Test initial state + self.assertEqual(datadict.keylength, None) + self.assertFalse(datadict.dict) # Empty dict, {} + + # Create indices and store data + datadict.create_index((0, 0)) + datadict[(0, 0)].store({"image": [0, 0, 0], "label": (0, 0)}) + + datadict.create_index((0, 1)) + datadict[(0, 1)].store({"image": [0, 1, 1], "label": (0, 1)}) + + datadict.create_index((1, 0)) + datadict[(1, 0)].store({"image": [1, 0, 2], "label": (1, 0)}) + + datadict.create_index((1, 1)) + datadict[(1, 1)].store({"image": [1, 1, 3], "label": (1, 1)}) + + self.assertEqual(datadict.keylength, 2) + self.assertEqual(len(datadict), 4) + + self.assertIn((0, 0), datadict.dict) + self.assertIn((0, 0), datadict.keys()) + self.assertIn((0, 1), datadict.dict) + self.assertIn((0, 1), datadict.keys()) + self.assertIn((1, 0), datadict.dict) + self.assertIn((1, 0), datadict.keys()) + self.assertIn((1, 1), datadict.dict) + self.assertIn((1, 1), datadict.keys()) + + # Test retrieving stored data + self.assertEqual( + datadict[(0, 0)].current_value(), + {"image": [0, 0, 0], "label": (0, 0)}, + ) + self.assertEqual( + datadict[(0, 1)].current_value(), + {"image": [0, 1, 1], "label": (0, 1)}, + ) + self.assertEqual( + datadict[(1, 0)].current_value(), + {"image": [1, 0, 2], "label": (1, 0)}, + ) + self.assertEqual( + datadict[(1, 1)].current_value(), + {"image": [1, 1, 3], "label": (1, 1)}, + ) + + # Test validation and invalidation - all + for key, value in datadict.items(): + self.assertTrue(value.is_valid()) - # Test initial state. 
- self.assertEqual(dataset.keylength, None) - self.assertFalse(dataset.dict) # Empty dict, {} + datadict.invalidate() + for key, value in datadict.items(): + self.assertFalse(value.is_valid()) - # Create indices and store data. - dataset.create_index((0,)) - dataset[(0,)].store({"image": [1, 2, 3], "label": 0}) + datadict.validate() + for key, value in datadict.items(): + self.assertTrue(value.is_valid()) - dataset.create_index((1,)) - dataset[(1,)].store({"image": [4, 5, 6], "label": 1}) + # Test validation and invalidation - single node + self.assertTrue(datadict[(0, 0)].is_valid()) - self.assertEqual(dataset.keylength, 1) - self.assertEqual(len(dataset.dict), 2) - self.assertIn((0,), dataset.dict) - self.assertIn((1,), dataset.dict) + datadict[(0, 0)].invalidate() + for key, value in datadict.items(): + if key == (0, 0): + self.assertFalse(value.is_valid()) + else: + self.assertTrue(value.is_valid()) - # Test retrieving stored data. - self.assertEqual(dataset[(0,)].current_value(), - {"image": [1, 2, 3], "label": 0}) - self.assertEqual(dataset[(1,)].current_value(), - {"image": [4, 5, 6], "label": 1}) + datadict[(1, 1)].invalidate() + for key, value in datadict.items(): + if key == (0, 0) or key == (1, 1): + self.assertFalse(value.is_valid()) + else: + self.assertTrue(value.is_valid()) - # Test validation and invalidation - all. 
- self.assertTrue(dataset[(0,)].is_valid()) - self.assertTrue(dataset[(1,)].is_valid()) + datadict[(0, 0)].validate() + for key, value in datadict.items(): + if key == (1, 1): + self.assertFalse(value.is_valid()) + else: + self.assertTrue(value.is_valid()) - dataset.invalidate() - self.assertFalse(dataset[(0,)].is_valid()) - self.assertFalse(dataset[(1,)].is_valid()) + datadict[(1, 1)].validate() + for key, value in datadict.items(): + self.assertTrue(value.is_valid()) - dataset.validate() - self.assertTrue(dataset[(0,)].is_valid()) - self.assertTrue(dataset[(1,)].is_valid()) + # Test valid_index + self.assertFalse(datadict.valid_index(())) - # Test validation and invalidation - single node. - self.assertTrue(dataset[(0,)].is_valid()) + self.assertFalse(datadict.valid_index((0,))) - dataset[(0,)].invalidate() - self.assertFalse(dataset[(0,)].is_valid()) - self.assertTrue(dataset[(1,)].is_valid()) + self.assertTrue(datadict.valid_index((0, 0))) + self.assertTrue(datadict.valid_index((1, 1))) + self.assertTrue(datadict.valid_index((2, 2))) - dataset[(1,)].invalidate() - self.assertFalse(dataset[(0,)].is_valid()) - self.assertFalse(dataset[(1,)].is_valid()) + self.assertFalse(datadict.valid_index((0, 1, 2))) - dataset[(0,)].validate() - self.assertTrue(dataset[(0,)].is_valid()) - self.assertFalse(dataset[(1,)].is_valid()) + # Test slicing: __getitem__ with shorter _ID + sliced = datadict[(0,)] + self.assertIsInstance(sliced, dict) - dataset[(1,)].validate() - self.assertTrue(dataset[(0,)].is_valid()) - self.assertTrue(dataset[(1,)].is_valid()) + self.assertIn((0, 0), sliced) + self.assertIsInstance(sliced[(0, 0)], core.DeepTrackDataObject) + self.assertIn((0, 1), sliced) + self.assertIsInstance(sliced[(0, 1)], core.DeepTrackDataObject) + + # Test trimming: __getitem__ with longer _ID + for key, value in datadict.items(): + self.assertEqual( + datadict[key + (99,)].current_value(), + datadict[key].current_value(), + ) + + # Test items(), keys(), values() + for item, 
key, value in zip( + datadict.items(), datadict.keys(), datadict.values() + ): + self.assertEqual(item[0], key) + self.assertEqual(item[1], value) + + # Test dict property access + self.assertIs(datadict.dict[(0, 0)], datadict[(0, 0)]) - # Test iteration over entries. - for key, value in dataset.dict.items(): - self.assertIn(key, {(0,), (1,)}) - self.assertIsInstance(value, core.DeepTrackDataObject) def test_DeepTrackNode_basics(self): + ## Without _ID node = core.DeepTrackNode(action=lambda: 42) - # Evaluate the node. + # Evaluate the node result = node() # Value is calculated and stored. self.assertEqual(result, 42) - # Store a value. + # Store a value node.store(100) # Value is stored. self.assertEqual(node.current_value(), 100) self.assertTrue(node.is_valid()) - # Invalidate the node and check the value. + # Invalidate the node and check the value node.invalidate() self.assertFalse(node.is_valid()) - self.assertEqual(node.current_value(), 100) # Value is retrieved. + self.assertEqual(node.current_value(), 100) # Value is retrieved self.assertFalse(node.is_valid()) - self.assertEqual(node(), 42) # Value is calculated and stored. 
+ self.assertEqual(node(), 42) # Value is calculated and stored self.assertTrue(node.is_valid()) + ## With _ID + node = core.DeepTrackNode(action=lambda _ID: _ID[0] * 10 + _ID[1]) + + # Store values + self.assertEqual(node((0, 0)), 0) + self.assertEqual(node((0, 1)), 1) + self.assertEqual(node((1, 0)), 10) + self.assertEqual(node((1, 1)), 11) + + # Check validity + self.assertFalse(node.is_valid()) + self.assertTrue(node.is_valid((0, 0))) + self.assertTrue(node.is_valid((0, 1))) + self.assertTrue(node.is_valid((1, 0))) + self.assertTrue(node.is_valid((1, 1))) + + # Invalidate + node.invalidate() + self.assertFalse(node.is_valid((0, 0))) + self.assertFalse(node.is_valid((0, 1))) + self.assertFalse(node.is_valid((1, 0))) + self.assertFalse(node.is_valid((1, 1))) + + def test_DeepTrackNode_new(self): + # Create a node with an action + node = core.DeepTrackNode(action=lambda: 42) + + # Manually store a different value + node.store(100) + self.assertEqual(node.current_value(), 100) + + # Call new() to reset and recompute + result = node.new() + self.assertEqual(result, 42) + self.assertEqual(node.current_value(), 42) + + # Also test with ID + node = core.DeepTrackNode(action=lambda _ID=None: _ID[0] * 2) + node.store(123, _ID=(3,)) + self.assertEqual(node.current_value((3,)), 123) + + result = node.new((3,)) + self.assertEqual(result, 6) + self.assertEqual(node.current_value((3,)), 6) + def test_DeepTrackNode_dependencies(self): - parent = core.DeepTrackNode(action=lambda: 10) - child = core.DeepTrackNode(action=lambda _ID=None: parent() * 2) - parent.add_child(child) # Establish dependency. + import random - # Check that the just create nodes are invalid as not calculated. 
+ parent = core.DeepTrackNode( + node_name="parent", + action=lambda: 10, + ) + child = core.DeepTrackNode( + node_name="child", + action=lambda: parent() * 2, + ) + grandchild = core.DeepTrackNode( + node_name="grandchild", + action=lambda: child() * 3, + ) + + # Establish dependencies + if random.randint(0, 1): # Test add_child() + parent.add_child(child) + else: # Test add_dependency() + child.add_dependency(parent) + + if random.randint(0, 1): # Test add_child() + child.add_child(grandchild) + else: # Test add_dependency() + grandchild.add_dependency(child) + + # Check that the just created nodes are invalid as not calculated self.assertFalse(parent.is_valid()) self.assertFalse(child.is_valid()) + self.assertFalse(grandchild.is_valid()) # Calculate child, and therefore parent. - result = child() - self.assertEqual(result, 20) + self.assertEqual(grandchild(), 60) self.assertTrue(parent.is_valid()) self.assertTrue(child.is_valid()) + self.assertTrue(grandchild.is_valid()) # Invalidate parent and check child validity. parent.invalidate() self.assertFalse(parent.is_valid()) self.assertFalse(child.is_valid()) + self.assertFalse(grandchild.is_valid()) - # Validate parent and ensure child is invalid until recomputation. - parent.validate() - self.assertTrue(parent.is_valid()) - self.assertFalse(child.is_valid()) + # Recompute child and check its validity. 
+ child.validate() + self.assertFalse(parent.is_valid()) + self.assertTrue(child.is_valid()) + self.assertFalse(grandchild.is_valid()) # Grandchild still invalid # Recompute child and check its validity - child() - self.assertTrue(parent.is_valid()) + grandchild() + self.assertFalse(parent.is_valid()) # Not recalculated as child valid self.assertTrue(child.is_valid()) + self.assertTrue(grandchild.is_valid()) - def test_DeepTrackNode_nested_dependencies(self): - parent = core.DeepTrackNode(action=lambda: 5) - middle = core.DeepTrackNode(action=lambda: parent() + 5) - child = core.DeepTrackNode(action=lambda: middle() * 2) - - parent.add_child(middle) - middle.add_child(child) - - result = child() - self.assertEqual(result, 20) - - # Invalidate the middle and check propagation. - middle.invalidate() + # Recompute child and check its validity + parent.invalidate() + grandchild() self.assertTrue(parent.is_valid()) - self.assertFalse(middle.is_valid()) - self.assertFalse(child.is_valid()) + self.assertTrue(child.is_valid()) + self.assertTrue(grandchild.is_valid()) + + # Check dependencies + self.assertEqual(len(parent.children), 1) + for node in parent.children: + self.assertEqual(node.node_name, "child") + self.assertEqual(len(child.children), 1) + for node in child.children: + self.assertEqual(node.node_name, "grandchild") + self.assertEqual(len(grandchild.children), 0) + + self.assertEqual(len(parent.dependencies), 0) + self.assertEqual(len(child.dependencies), 1) + for node in child.dependencies: + self.assertEqual(node.node_name, "parent") + self.assertEqual(len(grandchild.dependencies), 1) + for node in grandchild.dependencies: + self.assertEqual(node.node_name, "child") + + self.assertEqual(len(parent._all_children), 3) + self.assertEqual(len(child._all_children), 2) + self.assertEqual(len(grandchild._all_children), 1) + + self.assertEqual(len(parent.recurse_children()), 3) + self.assertEqual(len(child.recurse_children()), 2) + 
self.assertEqual(len(grandchild.recurse_children()), 1) + + self.assertEqual(len(parent.recurse_dependencies()), 1) + self.assertEqual(len(child.recurse_dependencies()), 2) + self.assertEqual(len(grandchild.recurse_dependencies()), 3) def test_DeepTrackNode_op_overloading(self): node1 = core.DeepTrackNode(action=lambda: 5) @@ -179,15 +348,66 @@ def test_DeepTrackNode_op_overloading(self): sum_node = node1 + node2 self.assertEqual(sum_node(), 15) - - diff_node = node2 - node1 - self.assertEqual(diff_node(), 5) + sum_node = node1 + 100 + self.assertEqual(sum_node(), 105) + sum_node = 100 + node2 + self.assertEqual(sum_node(), 110) + + diff_node = node1 - node2 + self.assertEqual(diff_node(), -5) + diff_node = node1 - 100 + self.assertEqual(diff_node(), -95) + diff_node = 100 - node2 + self.assertEqual(diff_node(), 90) prod_node = node1 * node2 self.assertEqual(prod_node(), 50) - - div_node = node2 / node1 - self.assertEqual(div_node(), 2) + prod_node = node1 * 100 + self.assertEqual(prod_node(), 500) + prod_node = 100 * node2 + self.assertEqual(prod_node(), 1_000) + + truediv_node = node2 / node1 + self.assertEqual(truediv_node(), 2) + truediv_node = node2 / 2 + self.assertEqual(truediv_node(), 5) + truediv_node = 50 / node1 + self.assertEqual(truediv_node(), 10) + + floordiv_node = node1 // node2 + self.assertEqual(floordiv_node(), 0) + floordiv_node = node1 // 2 + self.assertEqual(floordiv_node(), 2) + floordiv_node = 12 // node2 + self.assertEqual(floordiv_node(), 1) + + lt_node = node1 < node2 + self.assertTrue(lt_node()) + lt_node = node1 < 2 + self.assertFalse(lt_node()) + lt_node = 12 < node2 + self.assertFalse(lt_node()) + + gt_node = node1 > node2 + self.assertFalse(gt_node()) + gt_node = node1 > 2 + self.assertTrue(gt_node()) + gt_node = 12 > node2 + self.assertTrue(gt_node()) + + le_node = node1 <= node2 + self.assertTrue(le_node()) + le_node = node1 <= 2 + self.assertFalse(le_node()) + le_node = 12 <= node2 + self.assertFalse(le_node()) + + ge_node = node1 >= 
node2 + self.assertFalse(ge_node()) + ge_node = node1 >= 2 + self.assertTrue(ge_node()) + ge_node = 12 >= node2 + self.assertTrue(ge_node()) def test_DeepTrackNode_citations(self): node = core.DeepTrackNode(action=lambda: 42) @@ -224,16 +444,16 @@ def test_DeepTrackNode_nested_ids(self): parent.store(10, _ID=(1,)) # Compute child values for nested IDs - child_value_0_0 = child(_ID=(0, 0)) # Uses parent(_ID=(0,)). + child_value_0_0 = child(_ID=(0, 0)) # Uses parent(_ID=(0,)) self.assertEqual(child_value_0_0, 0) - child_value_0_1 = child(_ID=(0, 1)) # Uses parent(_ID=(0,)). + child_value_0_1 = child(_ID=(0, 1)) # Uses parent(_ID=(0,)) self.assertEqual(child_value_0_1, 5) - child_value_1_0 = child(_ID=(1, 0)) # Uses parent(_ID=(1,)). + child_value_1_0 = child(_ID=(1, 0)) # Uses parent(_ID=(1,)) self.assertEqual(child_value_1_0, 0) - child_value_1_1 = child(_ID=(1, 1)) # Uses parent(_ID=(1,)). + child_value_1_1 = child(_ID=(1, 1)) # Uses parent(_ID=(1,)) self.assertEqual(child_value_1_1, 10) def test_DeepTrackNode_replicated_behavior(self): @@ -251,7 +471,7 @@ def test_DeepTrackNode_replicated_behavior(self): def test_DeepTrackNode_parent_id_inheritance(self): - # Children with IDs matching than parents. + # Children with IDs matching those of the parents. 
parent_matching = core.DeepTrackNode(action=lambda: 10) child_matching = core.DeepTrackNode( action=lambda _ID=None: parent_matching(_ID[:1]) * 2 @@ -329,6 +549,7 @@ def test_DeepTrackNode_dependency_graph_with_ids(self): # 24 self.assertEqual(C_0_1_2, 24) + def test__equivalent(self): # Identity check (same object) a = [1, 2, 3] @@ -356,6 +577,7 @@ def test__equivalent(self): # One empty list, one non-list empty container self.assertFalse(core._equivalent([], ())) + def test__create_node_with_operator(self): import operator diff --git a/pyproject.toml b/pyproject.toml index 338d36dce..cc5d04060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,3 +22,6 @@ select = ["ALL"] [tool.ruff] line-length = 79 + +[tool.black] +line-length = 79 diff --git a/tutorials/2-examples/DTEx211_MNIST.ipynb b/tutorials/2-examples/DTEx211_MNIST.ipynb index e0f8f8b35..a8674ec90 100644 --- a/tutorials/2-examples/DTEx211_MNIST.ipynb +++ b/tutorials/2-examples/DTEx211_MNIST.ipynb @@ -164,7 +164,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Finally, you combine them into a single dataset that returns a tuple with a MNIST image and the corresponding ground-truth digit value." + "Finally, you combine them into a single dataset that returns a tuple with an MNIST image and the corresponding ground-truth digit value." ] }, { diff --git a/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb b/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb index 5a19deb87..5a20ef492 100644 --- a/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb +++ b/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb @@ -69,33 +69,26 @@ "\n", "In DeepTrack2, nodes represent computational units that can be flexibly linked into graphs by defining dependencies. 
This allows you to build modular, traceable, and efficient pipelines where changes automatically propagate through the graph.\n", "\n", - "Below we show how to set up parent-child relationships between nodes, store and compute data, and propagate invalidation when the upstream data changes." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from deeptrack.backend.core import DeepTrackNode" + "Below you will see how to set up parent-child relationships between nodes, store and compute data, and propagate invalidation when the upstream data changes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### 2.1. Creating Parent and Child Nodes\n", + "### 2.1. Creating the Parent and Child Nodes\n", "\n", - "We create a parent node and a child node whose value is always twice that of its parent." + "Create a parent node and a child node whose value is twice that of its parent." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ + "from deeptrack.backend.core import DeepTrackNode\n", + "\n", "# Create parent and child nodes\n", "parent = DeepTrackNode(action=lambda: 10)\n", "child = DeepTrackNode(action=lambda _ID=None: parent(_ID) * 2)" @@ -107,21 +100,21 @@ "source": [ "### 2.2. Establishing Parent-Child Dependency\n", "\n", - "We link the parent and child so that the child automatically tracks changes in the parent. In this way, the parent is updated or invalidated, this relationship ensures that the child is also kept up to date." + "Link the parent and child so that the child automatically tracks changes in the parent. In this way, this relationship ensures that the child is also kept up to date whenever the parent is updated or invalidated." 
] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "DeepTrackNode(len=0, action=)" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -137,21 +130,21 @@ "source": [ "### 2.3. Storing Values and Computing Results\n", "\n", - "Let’s assign different values to the parent for different data indices (`_ID`)." + "You can assign different values to the parent at once by associating them with different data indices (`_ID`s)." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "DeepTrackNode(len=2, action=, IDs=[(0,), (1,)])" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -173,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -182,7 +175,7 @@ "30" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -193,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -202,7 +195,7 @@ "40" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -223,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -232,14 +225,14 @@ "30" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Retrieve the cached value without recomputing\n", - "child.current_value((0,))" + "# Retrieve the cached value without computing\n", + "child.current_value((0,)) # Raise KeyError if value not already computed" ] }, { @@ -248,12 +241,12 @@ "source": [ "### 2.5. 
Validation and Invalidation\n", "\n", - "When you invalidate the parent for a particular _ID, the child’s value for that _ID will also be marked as invalid (since it depends on the parent). This ensures that downstream computations are never out of sync." + "When you invalidate the parent for a particular `_ID`, the child’s value for that `_ID` will also be marked as invalid (since it depends on the parent). This ensures that downstream computations are never out of sync." ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -262,20 +255,20 @@ "False" ] }, - "execution_count": 11, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Invalidate parent data for a given ID.\n", + "# Invalidate parent data for a given ID\n", "parent.invalidate((0,))\n", "parent.is_valid((0,))" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -284,7 +277,7 @@ "False" ] }, - "execution_count": 12, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -299,12 +292,12 @@ "source": [ "### 2.6. Updating and Recomputing Values\n", "\n", - "After invalidation, if we update the parent and request the child’s value again, it will be recomputed as needed." + "After invalidation, if you update the parent and request the child’s value again, it will be recomputed." 
] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -313,7 +306,7 @@ "50" ] }, - "execution_count": 13, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -326,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -335,7 +328,7 @@ "True" ] }, - "execution_count": 14, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -346,7 +339,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -355,7 +348,7 @@ "True" ] }, - "execution_count": 15, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -375,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -384,7 +377,7 @@ "100" ] }, - "execution_count": 16, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -396,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -405,7 +398,7 @@ "True" ] }, - "execution_count": 17, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -416,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -425,7 +418,7 @@ "False" ] }, - "execution_count": 18, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -434,6 +427,340 @@ "child.is_valid((1,))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.8. Evaluating, Resetting, and Recomputing Node Values\n", + "\n", + "The value of a `DeepTrackNode` is evaluated and stored when the node is called with `.__call__()` (e.g., `node()`). 
When you call a node multiple times, you will always get the same value as the node value is retrieved from memory (and not recomputed) each time.\n", + "\n", + "The state of the node (and its dependencies) can be reset using the `.update()` method. After this, calling the node will result in a recomputation of its value (e.g., `node.update()()`).\n", + "\n", + "For convenience, you can use the `.new()` method instead of `.update()()`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.8.1. Calling a Node\n", + "\n", + "You can evaluate a node by calling it with a specific (optional) `_ID`. This triggers the `__call__()` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "20" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "parent = DeepTrackNode(lambda: 10)\n", + "child = DeepTrackNode(lambda _ID=None: parent(_ID) * 2)\n", + "parent.add_child(child)\n", + "\n", + "child((0,)) # Triggers computation for both parent and child" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the node already has valid stored data for the given `_ID`, it returns that directly (cached result). Otherwise, it computes the value using the action, stores it, and then returns it. Thus, repeated calls with the same `_ID` will return cached results, unless invalidated." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.8.2. Resetting a Node\n", + "\n", + "In most cases, a `DeepTrackNode` automatically caches and reuses previously computed values unless explicitly invalidated. However, sometimes you may want to force a fresh computation; for instance, when the node's output is stochastic or time-dependent.\n", + "\n", + "The `.update()` method is provided exactly for this purpose: It clears all stored data in a node and its entire dependency graph. 
This invalidates everything and removes all cached results." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20\n" + ] + } + ], + "source": [ + "parent.store(20, _ID=(0,))\n", + "print(parent((0,))) # Output: 20 (from cache)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After calling `.update()`, the next evaluation will recompute all values." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10\n" + ] + } + ], + "source": [ + "parent.update()\n", + "print(parent((0,))) # Output: 10 (recomputed from action)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.8.3. Using `.new()` to Recompute a Node's Value\n", + "\n", + "The `.new()` method is equivalent to using `.update()()`, i.e., the `.update()` method to clear the values stored in the node and its dependencies followed by the `.__call__()` method to recompute the values from scratch." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For example, consider a node returning a random value." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "\n", + "node = DeepTrackNode(lambda _ID=None: random.randint(0, 100))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calling" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "32" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "node.new()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "is equivalent to (but more elegant than)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "77" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "node.update()()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.9. 
Creating and Visualizing a Complex Graph\n", + "\n", + "Now create a complex computational graph with the following structure:\n", + "\n", + " ┌────────┐ ┌────────┐\n", + " │ input1 │ │ input2 │\n", + " └────────┘ └────────┘\n", + " ↓ ↓\n", + " ┌────────────┐ ┌────────────┐\n", + " │ process1_A │ │ process2_A │\n", + " └────────────┘ └────────────┘\n", + " ↓ ↓\n", + " ┌────────────┐ ┌────────────┐\n", + " │ process1_B │ │ process2_B │\n", + " └────────────┘ └────────────┘\n", + " ↘ ↙\n", + " ┌────────────┐\n", + " │ merger │\n", + " └────────────┘\n", + " ↓\n", + " ┌────────────────┐\n", + " │ postprocess │\n", + " └────────────────┘\n", + " ↓ ↓\n", + " ┌──────┐ ┌──────┐\n", + " │ out1 │ │ out2 │\n", + " └──────┘ └──────┘" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# Shared input nodes\n", + "input1 = DeepTrackNode(lambda: 3, node_name=\"input1\")\n", + "input2 = DeepTrackNode(lambda: 5, node_name=\"input2\")\n", + "\n", + "# Independent preprocessing\n", + "process1_A = DeepTrackNode(lambda _ID=None: input1(_ID) + 1,\n", + " node_name=\"process1_A\")\n", + "process2_A = DeepTrackNode(lambda _ID=None: input2(_ID) * 2,\n", + " node_name=\"process2_A\")\n", + "\n", + "process1_B = DeepTrackNode(lambda _ID=None: process1_A(_ID) ** 2,\n", + " node_name=\"process1_B\")\n", + "process2_B = DeepTrackNode(lambda _ID=None: process2_A(_ID) - 1,\n", + " node_name=\"process2_B\")\n", + "\n", + "# Merge branch: sum both processed paths\n", + "merger = DeepTrackNode(\n", + " lambda _ID=None: process1_B(_ID) + process2_B(_ID),\n", + " node_name=\"merger\"\n", + ")\n", + "\n", + "# Post-processing\n", + "postprocess = DeepTrackNode(lambda _ID=None: merger(_ID) / 2,\n", + " node_name=\"postprocess\")\n", + "\n", + "# Split again\n", + "out1 = DeepTrackNode(lambda _ID=None: postprocess(_ID) + 100, node_name=\"out1\")\n", + "out2 = DeepTrackNode(lambda _ID=None: postprocess(_ID) * 3, node_name=\"out2\")\n", + "\n", 
+ "# Link nodes\n", + "input1.add_child(process1_A)\n", + "input2.add_child(process2_A)\n", + "\n", + "process1_A.add_child(process1_B)\n", + "process2_A.add_child(process2_B)\n", + "\n", + "process1_B.add_child(merger)\n", + "process2_B.add_child(merger)\n", + "\n", + "merger.add_child(postprocess)\n", + "\n", + "postprocess.add_child(out1)\n", + "postprocess.add_child(out2);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can print the children of, for example, `input1`." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "- DeepTrackNode 'input1' at 0x337a34e50\n", + " - DeepTrackNode 'process1_A' at 0x337a34e80\n", + " - DeepTrackNode 'process1_B' at 0x337a351b0\n", + " - DeepTrackNode 'merger' at 0x337a35690\n", + " - DeepTrackNode 'postprocess' at 0x337a34160\n", + " - DeepTrackNode 'out1' at 0x337a367a0\n", + " - DeepTrackNode 'out2' at 0x337a37730\n" + ] + } + ], + "source": [ + "input1.print_children_tree()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also print the dependencies of, for example, `out2`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "- DeepTrackNode 'out2' at 0x337a37730\n", + " - DeepTrackNode 'postprocess' at 0x337a34160\n", + " - DeepTrackNode 'merger' at 0x337a35690\n", + " - DeepTrackNode 'process1_B' at 0x337a351b0\n", + " - DeepTrackNode 'process1_A' at 0x337a34e80\n", + " - DeepTrackNode 'input1' at 0x337a34e50\n", + " - DeepTrackNode 'process2_B' at 0x337a35120\n", + " - DeepTrackNode 'process2_A' at 0x337a34f10\n", + " - DeepTrackNode 'input2' at 0x337a34df0\n" + ] + } + ], + "source": [ + "out2.print_dependencies_tree()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -442,7 +769,7 @@ "\n", "A powerful feature of DeepTrack2 nodes is lazy evaluation: the node’s value is only computed when it is needed, and the result is cached until the node (or its dependencies) is invalidated. This avoids redundant computations and ensures high efficiency, especially in large graphs.\n", "\n", - "In this example, we’ll use a global counter to demonstrate when the node’s computation actually happens." + "In this example, you will use a global counter to demonstrate when the node’s computation actually happens." ] }, { @@ -451,12 +778,12 @@ "source": [ "### 3.1 Defining a Node with a Side Effect\n", "\n", - "First, we define a calculation function that increments a global counter each time it is called. This allows us to see exactly how many times the node’s computation is performed." + "First, define a calculation function that increments a global counter each time it is called. This will allow you to see exactly how many times the node’s computation is performed." 
] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -481,7 +808,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -490,7 +817,7 @@ "10" ] }, - "execution_count": 20, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -502,7 +829,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -511,7 +838,7 @@ "1" ] }, - "execution_count": 21, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -522,7 +849,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -531,7 +858,7 @@ "10" ] }, - "execution_count": 22, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -543,7 +870,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -552,7 +879,7 @@ "1" ] }, - "execution_count": 23, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -567,12 +894,12 @@ "source": [ "### 3.3 Invalidation Forces Recalculation\n", "\n", - "If we invalidate the node, the cache is cleared and the next call will recompute the value:" + "If you invalidate the node, the cache is cleared and the next call will recompute the value:" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -581,7 +908,7 @@ "10" ] }, - "execution_count": 24, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -594,7 +921,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -603,7 +930,7 @@ "2" ] }, - "execution_count": 25, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -614,7 +941,7 @@ }, { "cell_type": "code", - "execution_count": 26, + 
"execution_count": 32, "metadata": {}, "outputs": [ { @@ -623,7 +950,7 @@ "10" ] }, - "execution_count": 26, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -636,7 +963,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -645,7 +972,7 @@ "3" ] }, - "execution_count": 27, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -654,15 +981,22 @@ "call_count" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You would obtain some similar results using the `.update()` or `.new()` methods." + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Data Management with IDs\n", "\n", - "In DeepTrack2, the `DeepTrackDataDict` class provides an efficient, validated way to manage multiple data objects, each indexed by a unique tuple of integers.\n", + "The `DeepTrackDataDict` class provides an efficient, validated way to manage multiple data objects, each indexed by a unique tuple of integers.\n", "\n", - "This is especially useful for working with multidimensional datasets, or for mapping results to experiment or batch indices." + "**NOTE:** This is particularly relevant when using the `Repeat` feature (accessed also throught the`^` operator)." ] }, { @@ -671,12 +1005,12 @@ "source": [ "### 4.1. Creating and Indexing Data Objects\n", "\n", - "You can create entries with arbitrary integer index tuples, just like keys in a nested dictionary." + "You can create entries with arbitrary integer index tuples, just like keys in a dictionary." 
] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -703,7 +1037,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -725,7 +1059,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -734,7 +1068,7 @@ "'Cat'" ] }, - "execution_count": 30, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -746,7 +1080,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -755,7 +1089,7 @@ "'Bird'" ] }, - "execution_count": 31, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -766,20 +1100,24 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 38, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{(0, 0): , (0, 1): }\n" - ] + "data": { + "text/plain": [ + "{(0, 0): DeepTrackDataObject(data='Cat', valid=True),\n", + " (0, 1): DeepTrackDataObject(data='Dog', valid=True)}" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "# Retrieve all entries whose indices start with (0,)\n", - "print(data_dict[(0, )])" + "data_dict[(0, )]" ] }, { @@ -804,7 +1142,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -816,7 +1154,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -825,7 +1163,7 @@ "8" ] }, - "execution_count": 34, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -837,7 +1175,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -846,7 +1184,7 @@ "2" ] }, - "execution_count": 35, + "execution_count": 41, "metadata": {}, 
"output_type": "execute_result" } @@ -858,7 +1196,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -867,7 +1205,7 @@ "10" ] }, - "execution_count": 36, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -879,7 +1217,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 43, "metadata": {}, "outputs": [ { @@ -888,7 +1226,7 @@ "1.6666666666666667" ] }, - "execution_count": 37, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -900,7 +1238,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 44, "metadata": {}, "outputs": [ { @@ -909,7 +1247,7 @@ "1" ] }, - "execution_count": 38, + "execution_count": 44, "metadata": {}, "output_type": "execute_result" } @@ -925,12 +1263,12 @@ "source": [ "### 5.2. Chaining and Nesting Operators\n", "\n", - "You can compose pipelines of arbitrary depth and complexity:" + "You can compose pipelines of arbitrary depth and complexity." ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 45, "metadata": {}, "outputs": [ { @@ -939,7 +1277,7 @@ "4.0" ] }, - "execution_count": 39, + "execution_count": 45, "metadata": {}, "output_type": "execute_result" } @@ -955,12 +1293,12 @@ "source": [ "### 5.3. Comparison Operators for Graphs\n", "\n", - "Comparison operators also work on nodes, returning new nodes that compute boolean results:" + "Comparison operators also work on nodes, returning new nodes that compute boolean results." 
] }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 46, "metadata": {}, "outputs": [ { @@ -969,7 +1307,7 @@ "False" ] }, - "execution_count": 40, + "execution_count": 46, "metadata": {}, "output_type": "execute_result" } @@ -981,7 +1319,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -990,7 +1328,7 @@ "True" ] }, - "execution_count": 41, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } @@ -1006,12 +1344,12 @@ "source": [ "### 5.4. Mixing Nodes and Constants\n", "\n", - "You can mix DeepTrackNode instances and regular numbers:" + "You can mix DeepTrackNode instances and regular numbers." ] }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 48, "metadata": {}, "outputs": [ { @@ -1020,7 +1358,7 @@ "12" ] }, - "execution_count": 42, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } @@ -1032,7 +1370,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 49, "metadata": {}, "outputs": [ { @@ -1041,7 +1379,7 @@ "9" ] }, - "execution_count": 43, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" } @@ -1062,7 +1400,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ -1071,7 +1409,7 @@ "{'\\n@article{Midtvet2021Quantitative,\\n author = {Midtvedt, Benjamin and Helgadottir, Saga and Argun, Aykut and \\n Pineda, Jesús and Midtvedt, Daniel and Volpe, Giovanni},\\n title = {Quantitative digital microscopy with deep learning},\\n journal = {Applied Physics Reviews},\\n volume = {8},\\n number = {1},\\n pages = {011310},\\n year = {2021},\\n doi = {10.1063/5.0034891}\\n}\\n'}" ] }, - "execution_count": 44, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } diff --git a/tutorials/4-developers/DTDV411_style.ipynb b/tutorials/4-developers/DTDV411_style.ipynb index 5bb783797..79f3be12a 100644 --- 
a/tutorials/4-developers/DTDV411_style.ipynb +++ b/tutorials/4-developers/DTDV411_style.ipynb @@ -309,7 +309,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -355,10 +355,10 @@ "\n", " **Setter methods.**\n", "\n", - " `set_1(new_value: int) -> None`\n", + " `set_1(new_value) -> None`\n", " Set first attribute.\n", "\n", - " `set_2(new_value: str) -> None`\n", + " `set_2(new_value) -> None`\n", " Set second attribute.\n", "\n", " Examples\n",