diff --git a/deeptrack/features.py b/deeptrack/features.py index a4bf2c70..fc6c13dd 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -222,7 +222,7 @@ def propagate_data_to_dependencies( "ChannelFirst2d", "Upscale", # TODO ***AL*** "NonOverlapping", # TODO ***AL*** - "Store", # TODO ***JH*** + "Store", "Squeeze", "Unsqueeze", "ExpandDims", @@ -8804,9 +8804,9 @@ def _resample_volume_position( class Store(Feature): """Store the output of a feature for reuse. - The `Store` feature evaluates a given feature and stores its output in an - internal dictionary. Subsequent calls with the same key will return the - stored value unless the `replace` parameter is set to `True`. This enables + The `Store` feature evaluates a given feature and stores its output in an + internal dictionary. Subsequent calls with the same key will return the + stored value unless the `replace` parameter is set to `True`. This enables caching and reuse of computed feature outputs. Parameters @@ -8815,10 +8815,10 @@ class Store(Feature): The feature to evaluate and store. key: Any The key used to identify the stored output. - replace: bool, optional - If `True`, replaces the stored value with a new computation. It defaults - to `False`. - **kwargs:: dict of str to Any + replace: PropertyLike[bool], optional + If `True`, replaces the stored value with the current computation. It + defaults to `False`. + **kwargs: dict of str to Any Additional keyword arguments passed to the parent `Feature` class. Attributes @@ -8852,22 +8852,26 @@ class Store(Feature): >>> cached_output = store_feature(None, key="example", replace=False) >>> print(cached_output == output) True + >>> print(cached_output == value_feature()) + False Retrieve the stored value recomputing: >>> value_feature.update() >>> cached_output = store_feature(None, key="example", replace=True) >>> print(cached_output == output) False + >>> print(cached_output == value_feature()) + True """ __distributed__: bool = False def __init__( - self: Store, + self: Feature, feature: Feature, key: Any, - replace: bool = False, + replace: PropertyLike[bool] = False, **kwargs: Any, ): """Initialize the Store feature. @@ -8878,8 +8882,8 @@ def __init__( The feature to evaluate and store. key: Any The key used to identify the stored output. - replace: bool, optional - If `True`, replaces the stored value with a new computation. + replace: PropertyLike[bool], optional + If `True`, replaces the stored value with a new computation. It defaults to `False`. **kwargs:: dict of str to Any Additional keyword arguments passed to the parent `Feature` class. @@ -8891,7 +8895,7 @@ def __init__( self._store: dict[Any, Image] = {} def get( - self: Store, + self: Feature, _: Any, key: Any, replace: bool, diff --git a/deeptrack/tests/test_features.py b/deeptrack/tests/test_features.py index 591547a6..94dd21f2 100644 --- a/deeptrack/tests/test_features.py +++ b/deeptrack/tests/test_features.py @@ -2408,10 +2408,35 @@ def test_Store(self): value_feature.update() cached_output = store_feature(None, key="example", replace=False) self.assertEqual(cached_output, output) + self.assertNotEqual(cached_output, value_feature()) value_feature.update() cached_output = store_feature(None, key="example", replace=True) self.assertNotEqual(cached_output, output) + self.assertEqual(cached_output, value_feature()) + + if TORCH_AVAILABLE: + + value_feature = features.Value(lambda: torch.rand(1)) + + store_feature = features.Store( + feature=value_feature, key="example" + ) + + output = store_feature(None, key="example", replace=False) + + value_feature.update() + cached_output = store_feature(None, key="example", replace=False) + torch.testing.assert_close(cached_output, output) + with self.assertRaises(AssertionError): + torch.testing.assert_close(cached_output, value_feature()) + + value_feature.update() + cached_output = store_feature(None, key="example", replace=True) + with self.assertRaises(AssertionError): + torch.testing.assert_close(cached_output, output) + torch.testing.assert_close(cached_output, value_feature()) + def test_Squeeze(self): diff --git a/tutorials/3-advanced-topics/DTAT391B_sources.folder.ipynb b/tutorials/3-advanced-topics/DTAT391B_sources.folder.ipynb index ef9e7d2d..464b37c5 100644 --- a/tutorials/3-advanced-topics/DTAT391B_sources.folder.ipynb +++ b/tutorials/3-advanced-topics/DTAT391B_sources.folder.ipynb @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -111,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -174,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -430,6 +430,13 @@ "List all image paths in train:" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Prepare an example path to one of the files:" + ] + }, { "cell_type": "code", "execution_count": 15,