diff --git a/tensordict/_td.py b/tensordict/_td.py
index df0660191..6a8c0c358 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -190,6 +190,7 @@ class TensorDict(TensorDictBase):
     _is_shared = False
     _is_memmap = False

+    @torch.compiler.disable()
     def __init__(
         self,
         source: T | dict[str, CompatibleType],
@@ -313,6 +314,7 @@ def is_empty(self):
             return False
         return True

+    @torch.compiler.disable()
     def _to_module(
         self,
         module,
@@ -432,6 +434,7 @@ def _quick_set(swap_dict, swap_td):
         else:
             return TensorDict(_swap, batch_size=[], _run_checks=False)

+    @torch.compiler.disable()
     def __ne__(self, other: object) -> T | bool:
         if _is_tensorclass(other):
             return other != self
@@ -498,6 +501,7 @@ def __or__(self, other: object) -> T | bool:
             )
         return False

+    @torch.compiler.disable()
     def __eq__(self, other: object) -> T | bool:
         if is_tensorclass(other):
             return other == self
diff --git a/tensordict/base.py b/tensordict/base.py
index 51c734dd1..a248116ff 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -204,6 +204,7 @@ def __contains__(self, key: NestedKey) -> bool:
             "`key in tensordict.keys()` instead."
         )

+    @torch.compiler.disable()
     def __getitem__(self, index: IndexType) -> T:
         """Indexes all tensors according to the provided index.

@@ -2025,6 +2026,7 @@ def entry_class(self, key: NestedKey) -> type:
         """
         ...

+    @torch.compiler.disable()
     def set(
         self, key: NestedKey, item: CompatibleType, inplace: bool = False, **kwargs: Any
     ) -> T:
@@ -2290,6 +2292,7 @@ def _default_get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleT
             _KEY_ERROR.format(key, self.__class__.__name__, sorted(self.keys()))
         )

+    @torch.compiler.disable()
     def get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
         """Gets the value stored with the input key.

@@ -2697,6 +2700,7 @@ def copy_at_(self, tensordict: T, idx: IndexType) -> T:
         """See :obj:`TensorDictBase.update_at_`."""
         return self.update_at_(tensordict, idx)

+    @torch.compiler.disable()
     def is_empty(self) -> bool:
         """Checks if the tensordict contains any leaf."""
         for _ in self.keys(True, True):
diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py
index 10ba56cc7..ffff7f85d 100644
--- a/tensordict/nn/common.py
+++ b/tensordict/nn/common.py
@@ -841,6 +841,7 @@ def reset_parameters_recursive(
             lambda x: x.detach().requires_grad_(), inplace=False
         )

+        is_stateless = False
         if _auto_make_functional() and not is_functional(self):
             make_functional(self, keep_params=True)
             is_stateless = self._is_stateless
diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py
index bb422cc02..252a49016 100644
--- a/tensordict/nn/probabilistic.py
+++ b/tensordict/nn/probabilistic.py
@@ -210,7 +210,6 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
         ...     TensorDictModule,
         ... )
         >>> from tensordict.nn.distributions import NormalParamExtractor
-        >>> from tensordict.nn.functional_modules import make_functional
         >>> from torch.distributions import Normal
         >>> td = TensorDict(
         ...     {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py
index eb7e14cef..0ca0603e4 100644
--- a/tensordict/nn/sequence.py
+++ b/tensordict/nn/sequence.py
@@ -88,7 +88,6 @@ class TensorDictSequential(TensorDictModule):
         ...     TensorDictSequential,
         ... )
         >>> from tensordict.nn.distributions import NormalParamExtractor
-        >>> from tensordict.nn.functional_modules import make_functional
         >>> from torch.distributions import Normal
         >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
         >>> net1 = torch.nn.Linear(4, 8)
diff --git a/test/test_nn.py b/test/test_nn.py
index fd5aefeff..4dadd6f47 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -172,6 +172,7 @@ def test_reset(self):
             nn.Sequential(nn.Tanh(), nn.Linear(1, 1), nn.Linear(2, 1)),
         ],
     )
+    @_set_auto_make_functional(True)
     def test_reset_functional(self, net):
         torch.manual_seed(0)
         module = TensorDictModule(net, in_keys=["in"], out_keys=["out"])
@@ -204,6 +205,7 @@ def test_reset_functional(self, net):
                 p.all()
             ), f"Discrepancy between returned weights and those in-place updated {p}"

+    @_set_auto_make_functional(True)
     def test_reset_functional_called_once(self):
         import unittest.mock

@@ -403,6 +405,7 @@ def test_stateful_probabilistic(self, lazy, interaction_type, out_keys):
     @pytest.mark.skipif(
         not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}"
     )
+    @_set_auto_make_functional(True)
     def test_functional_before(self):
         torch.manual_seed(0)
         param_multiplier = 1
@@ -569,6 +572,7 @@ def test_functional_probabilistic(self):
     @pytest.mark.skipif(
         not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}"
     )
+    @_set_auto_make_functional(True)
     def test_functional_with_buffer(self):
         torch.manual_seed(0)
         param_multiplier = 1
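Note on the patch above: `torch.compiler.disable()` marks a function so that `torch.compile` skips tracing it and falls back to eager execution (a graph break) whenever it is reached from compiled code, which is what the decorated `TensorDict` methods rely on here. The `is_stateless = False` line in `reset_parameters_recursive` initializes the flag so it is bound even when the `_auto_make_functional()` branch is not taken. Below is a minimal, self-contained sketch of the decorator's effect, assuming PyTorch >= 2.1; the names `eager_lookup` and `fn` are hypothetical and not part of the patch.

```python
import torch


# Hypothetical helper (not from the patch): Python dict plumbing we don't
# want Dynamo to trace. Called from compiled code, it triggers a graph
# break and runs eagerly, like the decorated TensorDict methods above.
@torch.compiler.disable()
def eager_lookup(d, key):
    return d[key]


@torch.compile
def fn(d):
    # The surrounding arithmetic is still compiled; only the call to
    # eager_lookup falls back to eager execution.
    return eager_lookup(d, "x") * 2


print(fn({"x": torch.ones(3)}))  # tensor([2., 2., 2.])
```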