From 8fb658607b922239847b106913410429f4ae3e45 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 18 Jan 2024 11:15:57 +0000 Subject: [PATCH 1/4] init --- tensordict/_td.py | 2 + tensordict/base.py | 4 ++ tensordict/nn/common.py | 62 ++++++++++++++++------------- tensordict/nn/ensemble.py | 11 +++-- tensordict/nn/functional_modules.py | 37 +++++++++++++++++ tensordict/nn/probabilistic.py | 14 ++++--- tensordict/nn/sequence.py | 15 ++++--- test/test_nn.py | 10 ++++- 8 files changed, 109 insertions(+), 46 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 3d62a19cd..15926aa21 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -188,6 +188,7 @@ class TensorDict(TensorDictBase): _is_shared = False _is_memmap = False + @torch.compiler.disable def __init__( self, source: T | dict[str, CompatibleType], @@ -298,6 +299,7 @@ def is_empty(self): return True @as_decorator() + @torch.compiler.disable def to_module( self, module, diff --git a/tensordict/base.py b/tensordict/base.py index 86384dfa4..564ee0680 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -200,6 +200,7 @@ def __contains__(self, key: NestedKey) -> bool: "`key in tensordict.keys()` instead." ) + @torch.compiler.disable def __getitem__(self, index: IndexType) -> T: """Indexes all tensors according to the provided index. @@ -1789,6 +1790,7 @@ def entry_class(self, key: NestedKey) -> type: """ ... + @torch.compiler.disable def set( self, key: NestedKey, item: CompatibleType, inplace: bool = False, **kwargs: Any ) -> T: @@ -2056,6 +2058,7 @@ def _default_get( _KEY_ERROR.format(key, self.__class__.__name__, sorted(self.keys())) ) + @torch.compiler.disable def get( self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT ) -> CompatibleType: @@ -2441,6 +2444,7 @@ def copy_at_(self, tensordict: T, idx: IndexType) -> T: """See :obj:`TensorDictBase.update_at_`.""" return self.update_at_(tensordict, idx) + @torch.compiler.disable def is_empty(self) -> bool: """Checks if the tensordict contains any leaf.""" for _ in self.keys(True, True): diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 53df615aa..0e990b57f 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -18,6 +18,7 @@ from tensordict.functional import make_tensordict from tensordict.nn.functional_modules import ( + _auto_make_functional, _swap_state, extract_weights_and_buffers, is_functional, @@ -829,29 +830,33 @@ def reset_parameters_recursive( lambda x: x.detach().requires_grad_(), inplace=False ) - if not is_functional(self): + is_stateless = False + if _auto_make_functional() and not is_functional(self): make_functional(self, keep_params=True) - is_stateless = self._is_stateless - if is_stateless: - repopulate_module(self, sanitized_parameters) - else: - old_params = _swap_state( - self, - sanitized_parameters, - is_stateless=False, - return_old_tensordict=True, - ) + is_stateless = self._is_stateless + if is_stateless: + repopulate_module(self, sanitized_parameters) + else: + old_params = _swap_state( + self, + sanitized_parameters, + is_stateless=False, + return_old_tensordict=True, + ) - self._reset_parameters(self) + self._reset_parameters(self) - if is_stateless: - new_parameters = extract_weights_and_buffers(self) + if is_stateless: + new_parameters = extract_weights_and_buffers(self) + else: + new_parameters = _swap_state( + self, old_params, is_stateless=False, return_old_tensordict=True + ) + return new_parameters else: - new_parameters = _swap_state( - self, old_params, is_stateless=False, return_old_tensordict=True - ) - - return new_parameters + with sanitized_parameters.to_module(self): + self._reset_parameters(self) + return sanitized_parameters def _reset_parameters(self, module: nn.Module) -> None: for child in module.children(): @@ -865,10 +870,6 @@ def _reset_parameters(self, module: nn.Module) -> None: class TensorDictModule(TensorDictModuleBase): """A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict. - By default, :class:`TensorDictModule` subclasses are always functional, - meaning that they support the ``td_module(input, params=params)`` function - call signature. - Args: module (Callable): a callable, typically a :class:`torch.nn.Module`, used to map the input to the output parameter space. Its forward method @@ -966,14 +967,15 @@ class TensorDictModule(TensorDictModuleBase): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule - >>> from tensordict.nn.functional_modules import make_functional >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) >>> module = torch.nn.GRUCell(4, 8) >>> td_module = TensorDictModule( ... module=module, in_keys=["input", "hidden"], out_keys=["output"] ... ) - >>> params = make_functional(td_module) - >>> td_functional = td_module(td.clone(), params=params) + >>> params = TensorDict.from_module(td_module) + >>> # functional API + >>> with params.to_module(td_module): + ... td_functional = td_module(td.clone()) >>> print(td_functional) TensorDict( fields={ @@ -1022,7 +1024,10 @@ class TensorDictModule(TensorDictModuleBase): batch_size=torch.Size([4]), device=None, is_shared=False) - >>> td_vmap = vmap(td_module, (None, 0))(td.clone(), params_repeat) + >>> def func(td, params): + ... with params.to_module(td_module): + ... return td_module(td) + >>> td_vmap = vmap(func, (None, 0))(td.clone(), params_repeat) >>> print(td_vmap) TensorDict( fields={ @@ -1089,7 +1094,8 @@ def __init__( ) self.module = module - make_functional(self, keep_params=True, return_params=False) + if _auto_make_functional(): + make_functional(self, keep_params=True, return_params=False) @property def is_functional(self) -> bool: diff --git a/tensordict/nn/ensemble.py b/tensordict/nn/ensemble.py index 2cefe015e..8c28dc3d7 100644 --- a/tensordict/nn/ensemble.py +++ b/tensordict/nn/ensemble.py @@ -8,7 +8,6 @@ import torch from tensordict import TensorDict from tensordict.nn.common import TensorDictBase, TensorDictModuleBase -from tensordict.nn.functional_modules import make_functional from tensordict.nn.params import TensorDictParams @@ -76,17 +75,21 @@ def __init__( super().__init__() self.in_keys = module.in_keys self.out_keys = module.out_keys - params_td = make_functional(module).expand(num_copies).to_tensordict() + params_td = TensorDict.from_module(module).expand(num_copies).to_tensordict() self.module = module if expand_input: - self.vmapped_forward = torch.vmap(self.module, (None, 0)) + self.vmapped_forward = torch.vmap(self._func_module_call, (None, 0)) else: - self.vmapped_forward = torch.vmap(self.module, 0) + self.vmapped_forward = torch.vmap(self._func_module_call, 0) self.reset_parameters_recursive(params_td) self.params_td = TensorDictParams(params_td) + def _func_module_call(self, input, params): + with params.to_module(self.module): + return self.module(input) + def forward(self, tensordict: TensorDict) -> TensorDict: return self.vmapped_forward(tensordict, self.params_td) diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py index 92b08af69..354b8af29 100644 --- a/tensordict/nn/functional_modules.py +++ b/tensordict/nn/functional_modules.py @@ -6,14 +6,17 @@ from __future__ import annotations import inspect +import os import re import types import warnings from copy import deepcopy +from distutils.util import strtobool from functools import wraps from typing import Any, Callable, Iterable import torch +from tensordict._contextlib import _DecoratorContextManager from tensordict._pytree import PYTREE_REGISTERED_TDS from tensordict._td import TensorDict @@ -29,6 +32,40 @@ # old torch version, passing pass + +AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "False")) + + +def _auto_make_functional(): + global AUTO_MAKE_FUNCTIONAL + return AUTO_MAKE_FUNCTIONAL + + +class _set_auto_make_functional(_DecoratorContextManager): + def __init__(self, mode): + self.mode = mode + + def __call__(self, func): + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func + + def clone(self): + return self.__class__(self.mode) + + def __enter__(self): + global AUTO_MAKE_FUNCTIONAL + self._saved_mode = AUTO_MAKE_FUNCTIONAL + AUTO_MAKE_FUNCTIONAL = self.mode + + def __exit__(self, exc_type, exc_val, exc_tb): + global AUTO_MAKE_FUNCTIONAL + AUTO_MAKE_FUNCTIONAL = self._saved_mode + + __base__setattr__ = nn.Module.__setattr__ diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index ae40e6cb2..739ea02b9 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -204,7 +204,6 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase): ... TensorDictModule, ... ) >>> from tensordict.nn.distributions import NormalParamExtractor - >>> from tensordict.nn.functional_modules import make_functional >>> from torch.distributions import Normal >>> td = TensorDict( ... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3] @@ -225,8 +224,9 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase): >>> td_module = ProbabilisticTensorDictSequential( ... module, normal_params, prob_module ... ) - >>> params = make_functional(td_module, funs_to_decorate=["forward", "get_dist", "log_prob"]) - >>> _ = td_module(td, params=params) + >>> params = TensorDict.from_module(td_module) + >>> with params.to_module(td_module): + ... _ = td_module(td) >>> print(td) TensorDict( fields={ @@ -240,13 +240,17 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase): batch_size=torch.Size([3]), device=None, is_shared=False) - >>> dist = td_module.get_dist(td, params=params) + >>> with params.to_module(td_module): + ... dist = td_module.get_dist(td) >>> print(dist) Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])) >>> # we can also apply the module to the TensorDict with vmap >>> from torch import vmap >>> params = params.expand(4) - >>> td_vmap = vmap(td_module, (None, 0))(td, params) + >>> def func(td, params): + ... with params.to_module(td_module): + ... return td_module(td) + >>> td_vmap = vmap(func, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index a5265d849..0ca0603e4 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -38,10 +38,6 @@ class TensorDictSequential(TensorDictModule): """A sequence of TensorDictModules. - By default, :class:`TensorDictSequential` subclasses are always functional, - meaning that they support the ``td_module(input, params=params)`` function - call signature. - Similarly to :obj:`nn.Sequence` which passes a tensor through a chain of mappings that read and write a single tensor each, this module will read and write over a tensordict by querying each of the input modules. When calling a :obj:`TensorDictSequencial` instance with a functional module, it is expected that the parameter lists (and @@ -92,7 +88,6 @@ class TensorDictSequential(TensorDictModule): ... TensorDictSequential, ... ) >>> from tensordict.nn.distributions import NormalParamExtractor - >>> from tensordict.nn.functional_modules import make_functional >>> from torch.distributions import Normal >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) >>> net1 = torch.nn.Linear(4, 8) @@ -115,8 +110,9 @@ class TensorDictSequential(TensorDictModule): ... module=module2, in_keys=["hidden"], out_keys=["output"] ... ) >>> td_module = TensorDictSequential(td_module1, td_module2) - >>> params = make_functional(td_module) - >>> _ = td_module(td, params=params) + >>> params = TensorDict.from_module(td_module) + >>> with params.to_module(td_module): + ... _ = td_module(td) >>> print(td) TensorDict( fields={ @@ -134,7 +130,10 @@ class TensorDictSequential(TensorDictModule): In the vmap case: >>> from torch import vmap >>> params = params.expand(4) - >>> td_vmap = vmap(td_module, (None, 0))(td, params) + >>> def func(td, params): + ... with params.to_module(td_module): + ... return td_module(td) + >>> td_vmap = vmap(func, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ diff --git a/test/test_nn.py b/test/test_nn.py index b894b3756..af36faec7 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -32,7 +32,11 @@ ) from tensordict.nn.distributions.composite import CompositeDistribution from tensordict.nn.ensemble import EnsembleModule -from tensordict.nn.functional_modules import is_functional, make_functional +from tensordict.nn.functional_modules import ( + _set_auto_make_functional, + is_functional, + make_functional, +) from tensordict.nn.probabilistic import InteractionType, set_interaction_type from tensordict.nn.utils import Buffer, set_skip_existing, skip_existing from torch import distributions as d, nn @@ -166,6 +170,7 @@ def test_reset(self): nn.Sequential(nn.Tanh(), nn.Linear(1, 1), nn.Linear(2, 1)), ], ) + @_set_auto_make_functional(True) def test_reset_functional(self, net): torch.manual_seed(0) module = TensorDictModule(net, in_keys=["in"], out_keys=["out"]) @@ -198,6 +203,7 @@ def test_reset_functional(self, net): p.all() ), f"Discrepancy between returned weights and those in-place updated {p}" + @_set_auto_make_functional(True) def test_reset_functional_called_once(self): import unittest.mock @@ -397,6 +403,7 @@ def test_stateful_probabilistic(self, lazy, interaction_type, out_keys): @pytest.mark.skipif( not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}" ) + @_set_auto_make_functional(True) def test_functional_before(self): torch.manual_seed(0) param_multiplier = 1 @@ -541,6 +548,7 @@ def test_functional_probabilistic(self): @pytest.mark.skipif( not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}" ) + @_set_auto_make_functional(True) def test_functional_with_buffer(self): torch.manual_seed(0) param_multiplier = 1 From be802e99268f93952f158ff220c77e76c217e738 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 18 Jan 2024 12:36:16 +0000 Subject: [PATCH 2/4] amend --- tensordict/_td.py | 6 ++- tensordict/base.py | 8 ++-- tensordict/nn/common.py | 12 ++++- tensordict/nn/functional_modules.py | 36 +-------------- tensordict/nn/utils.py | 68 +++++++++++++++++++++++++++++ test/test_nn.py | 11 ++--- 6 files changed, 94 insertions(+), 47 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 15926aa21..494de8870 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -188,7 +188,7 @@ class TensorDict(TensorDictBase): _is_shared = False _is_memmap = False - @torch.compiler.disable + @torch.compiler.disable() def __init__( self, source: T | dict[str, CompatibleType], @@ -299,7 +299,7 @@ def is_empty(self): return True @as_decorator() - @torch.compiler.disable + @torch.compiler.disable() def to_module( self, module, @@ -429,6 +429,7 @@ def convert_type(x, y): swap.update(_swap) return swap + @torch.compiler.disable() def __ne__(self, other: object) -> T | bool: if _is_tensorclass(other): return other != self @@ -495,6 +496,7 @@ def __or__(self, other: object) -> T | bool: ) return False + @torch.compiler.disable() def __eq__(self, other: object) -> T | bool: if is_tensorclass(other): return other == self diff --git a/tensordict/base.py b/tensordict/base.py index 564ee0680..964539f5e 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -200,7 +200,7 @@ def __contains__(self, key: NestedKey) -> bool: "`key in tensordict.keys()` instead." ) - @torch.compiler.disable + @torch.compiler.disable() def __getitem__(self, index: IndexType) -> T: """Indexes all tensors according to the provided index. @@ -1790,7 +1790,7 @@ def entry_class(self, key: NestedKey) -> type: """ ... - @torch.compiler.disable + @torch.compiler.disable() def set( self, key: NestedKey, item: CompatibleType, inplace: bool = False, **kwargs: Any ) -> T: @@ -2058,7 +2058,7 @@ def _default_get( _KEY_ERROR.format(key, self.__class__.__name__, sorted(self.keys())) ) - @torch.compiler.disable + @torch.compiler.disable() def get( self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT ) -> CompatibleType: @@ -2444,7 +2444,7 @@ def copy_at_(self, tensordict: T, idx: IndexType) -> T: """See :obj:`TensorDictBase.update_at_`.""" return self.update_at_(tensordict, idx) - @torch.compiler.disable + @torch.compiler.disable() def is_empty(self) -> bool: """Checks if the tensordict contains any leaf.""" for _ in self.keys(True, True): diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 0e990b57f..fa82c3f76 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -18,7 +18,6 @@ from tensordict.functional import make_tensordict from tensordict.nn.functional_modules import ( - _auto_make_functional, _swap_state, extract_weights_and_buffers, is_functional, @@ -26,7 +25,11 @@ repopulate_module, ) -from tensordict.nn.utils import set_skip_existing +from tensordict.nn.utils import ( + _auto_make_functional, + _dispatch_td_nn_modules, + set_skip_existing, +) from tensordict.utils import implement_for, NestedKey from torch import nn, Tensor @@ -239,10 +242,15 @@ def __call__(self, func: Callable) -> Callable: "named 'tensordict'." ) break + if not _dispatch_td_nn_modules(): + return func @functools.wraps(func) def wrapper(_self, *args: Any, **kwargs: Any) -> Any: + if not _dispatch_td_nn_modules(): + return func(*args, **kwargs) + source = self.source if isinstance(source, str): source = getattr(_self, source) diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py index 354b8af29..0abf0351d 100644 --- a/tensordict/nn/functional_modules.py +++ b/tensordict/nn/functional_modules.py @@ -15,6 +15,8 @@ from functools import wraps from typing import Any, Callable, Iterable +import tensordict.nn.utils + import torch from tensordict._contextlib import _DecoratorContextManager from tensordict._pytree import PYTREE_REGISTERED_TDS @@ -32,40 +34,6 @@ # old torch version, passing pass - -AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "False")) - - -def _auto_make_functional(): - global AUTO_MAKE_FUNCTIONAL - return AUTO_MAKE_FUNCTIONAL - - -class _set_auto_make_functional(_DecoratorContextManager): - def __init__(self, mode): - self.mode = mode - - def __call__(self, func): - @wraps(func) - def new_func(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return new_func - - def clone(self): - return self.__class__(self.mode) - - def __enter__(self): - global AUTO_MAKE_FUNCTIONAL - self._saved_mode = AUTO_MAKE_FUNCTIONAL - AUTO_MAKE_FUNCTIONAL = self.mode - - def __exit__(self, exc_type, exc_val, exc_tb): - global AUTO_MAKE_FUNCTIONAL - AUTO_MAKE_FUNCTIONAL = self._saved_mode - - __base__setattr__ = nn.Module.__setattr__ diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index a15a71b28..51d43c858 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -7,12 +7,20 @@ import functools import inspect +import os +from distutils.util import strtobool from typing import Any, Callable import torch from torch import nn +AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "False")) + + +DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True")) + __all__ = ["mappings", "inv_softplus", "biased_softplus"] + _SKIP_EXISTING = False from tensordict._contextlib import _DecoratorContextManager @@ -287,3 +295,63 @@ def _rebuild_buffer(data, requires_grad, backward_hooks): # For backward compatibility in imports from tensordict.utils import Buffer # noqa + + +def _auto_make_functional(): + global DISPATCH_TDNN_MODULES + return AUTO_MAKE_FUNCTIONAL + + +class _set_auto_make_functional(_DecoratorContextManager): + def __init__(self, mode): + self.mode = mode + + def __call__(self, func): + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func + + def clone(self): + return self.__class__(self.mode) + + def __enter__(self): + global AUTO_MAKE_FUNCTIONAL + self._saved_mode = AUTO_MAKE_FUNCTIONAL + AUTO_MAKE_FUNCTIONAL = self.mode + + def __exit__(self, exc_type, exc_val, exc_tb): + global AUTO_MAKE_FUNCTIONAL + AUTO_MAKE_FUNCTIONAL = self._saved_mode + + +def _dispatch_td_nn_modules(): + global DISPATCH_TDNN_MODULES + return DISPATCH_TDNN_MODULES + + +class _set_dispatch_td_nn_modules(_DecoratorContextManager): + def __init__(self, mode): + self.mode = mode + + def __call__(self, func): + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func + + def clone(self): + return self.__class__(self.mode) + + def __enter__(self): + global DISPATCH_TDNN_MODULES + self._saved_mode = DISPATCH_TDNN_MODULES + DISPATCH_TDNN_MODULES = self.mode + + def __exit__(self, exc_type, exc_val, exc_tb): + global DISPATCH_TDNN_MODULES + DISPATCH_TDNN_MODULES = self._saved_mode diff --git a/test/test_nn.py b/test/test_nn.py index af36faec7..1aed8012d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -32,13 +32,14 @@ ) from tensordict.nn.distributions.composite import CompositeDistribution from tensordict.nn.ensemble import EnsembleModule -from tensordict.nn.functional_modules import ( +from tensordict.nn.functional_modules import is_functional, make_functional +from tensordict.nn.probabilistic import InteractionType, set_interaction_type +from tensordict.nn.utils import ( _set_auto_make_functional, - is_functional, - make_functional, + Buffer, + set_skip_existing, + skip_existing, ) -from tensordict.nn.probabilistic import InteractionType, set_interaction_type -from tensordict.nn.utils import Buffer, set_skip_existing, skip_existing from torch import distributions as d, nn from torch.distributions import Normal from torch.utils._pytree import tree_map From 25e87d3c8adc1702016c16a1fd83735e667d4a2f Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 18 Jan 2024 14:35:05 +0000 Subject: [PATCH 3/4] amend --- tensordict/nn/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 51d43c858..2095582cc 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -307,7 +307,7 @@ def __init__(self, mode): self.mode = mode def __call__(self, func): - @wraps(func) + @functools.wraps(func) def new_func(*args, **kwargs): with self: return func(*args, **kwargs) From 153325c68c2f64d5a9452351ce88377edd2a4764 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 13 Feb 2024 21:18:02 +0000 Subject: [PATCH 4/4] amend --- tensordict/nn/functional_modules.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py index 0abf0351d..92b08af69 100644 --- a/tensordict/nn/functional_modules.py +++ b/tensordict/nn/functional_modules.py @@ -6,19 +6,14 @@ from __future__ import annotations import inspect -import os import re import types import warnings from copy import deepcopy -from distutils.util import strtobool from functools import wraps from typing import Any, Callable, Iterable -import tensordict.nn.utils - import torch -from tensordict._contextlib import _DecoratorContextManager from tensordict._pytree import PYTREE_REGISTERED_TDS from tensordict._td import TensorDict