
Commit

init
vmoens committed Jun 18, 2024
1 parent 35df59e commit 748b2ba
Showing 11 changed files with 128 additions and 35 deletions.
50 changes: 49 additions & 1 deletion test/test_specs.py
@@ -29,7 +29,11 @@
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)
-from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec
+from torchrl.data.utils import (
+    _make_ordinal_device,
+    check_no_exclusive_keys,
+    consolidate_spec,
+)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])
@@ -3689,6 +3693,50 @@ def test_sample(self):
        assert nts.zero((2,)).shape == (2, 3, 4)


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device available")
+def test_device_ordinal():
+    device = torch.device("cpu")
+    assert _make_ordinal_device(device) == torch.device("cpu")
+    device = torch.device("cuda")
+    assert _make_ordinal_device(device) == torch.device("cuda:0")
+    device = torch.device("cuda:0")
+    assert _make_ordinal_device(device) == torch.device("cuda:0")
+    device = None
+    assert _make_ordinal_device(device) is None
+
+    device = torch.device("cuda")
+    unb = UnboundedContinuousTensorSpec((-1, 1, 2), device=device)
+    assert unb.device == torch.device("cuda:0")
+    unbd = UnboundedDiscreteTensorSpec((-1, 1, 2), device=device)
+    assert unbd.device == torch.device("cuda:0")
+    bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1, device=device)
+    assert bound.device == torch.device("cuda:0")
+    oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4, device=device)
+    assert oneh.device == torch.device("cuda:0")
+    disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4, device=device)
+    assert disc.device == torch.device("cuda:0")
+    moneh = MultiOneHotDiscreteTensorSpec(
+        shape=(-1, 1, 2, 7), nvec=[3, 4], device=device
+    )
+    assert moneh.device == torch.device("cuda:0")
+    mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device)
+    assert mdisc.device == torch.device("cuda:0")
+    nts = NonTensorSpec(shape=(-1, 1, 2, 2), device=device)  # distinct name so it doesn't shadow mdisc
+    assert nts.device == torch.device("cuda:0")
+
+    spec = CompositeSpec(
+        unb=unb,
+        unbd=unbd,
+        bound=bound,
+        oneh=oneh,
+        disc=disc,
+        moneh=moneh,
+        mdisc=mdisc,
+        shape=(-1, 1, 2),
+    )
+    assert spec.device == torch.device("cuda:0")


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
16 changes: 11 additions & 5 deletions torchrl/collectors/collectors.py
@@ -50,7 +50,7 @@
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data.tensor_specs import TensorSpec
-from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
+from torchrl.data.utils import _make_ordinal_device, CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
@@ -820,10 +820,16 @@ def _get_devices(
        env_device: torch.device,
        device: torch.device,
    ):
-        device = torch.device(device) if device else device
-        storing_device = torch.device(storing_device) if storing_device else device
-        policy_device = torch.device(policy_device) if policy_device else device
-        env_device = torch.device(env_device) if env_device else device
+        device = _make_ordinal_device(torch.device(device) if device else device)
+        storing_device = _make_ordinal_device(
+            torch.device(storing_device) if storing_device else device
+        )
+        policy_device = _make_ordinal_device(
+            torch.device(policy_device) if policy_device else device
+        )
+        env_device = _make_ordinal_device(
+            torch.device(env_device) if env_device else device
+        )
        if storing_device is None and (env_device == policy_device):
            storing_device = env_device
        return storing_device, policy_device, env_device
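Taken together, this hunk gives `_get_devices` a simple precedence rule: each of the three devices wins when passed explicitly, otherwise it inherits the shared `device`, and `storing_device` is additionally inferred when the env and policy devices agree. A standalone sketch of that rule follows; it is illustrative only, not the collector's actual method:

import torch

from torchrl.data.utils import _make_ordinal_device

def resolve_devices(storing_device=None, policy_device=None, env_device=None, device=None):
    # Each specific device falls back to the shared `device` when unset.
    device = _make_ordinal_device(torch.device(device) if device else device)
    storing_device = _make_ordinal_device(
        torch.device(storing_device) if storing_device else device
    )
    policy_device = _make_ordinal_device(
        torch.device(policy_device) if policy_device else device
    )
    env_device = _make_ordinal_device(torch.device(env_device) if env_device else device)
    # Ordinal normalization makes this equality reliable: "cuda" and "cuda:0"
    # would otherwise compare unequal and the inference would silently fail.
    if storing_device is None and (env_device == policy_device):
        storing_device = env_device
    return storing_device, policy_device, env_device

# resolve_devices(env_device="cuda", policy_device="cuda:0") now infers
# storing_device as cuda:0, since both inputs normalize to the same device.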
4 changes: 2 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -58,7 +58,7 @@
    Writer,
    WriterEnsemble,
)
-from torchrl.data.utils import DEVICE_TYPING
+from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.envs.transforms.transforms import _InvertTransform


@@ -1457,7 +1457,7 @@ def __init__(self, device: DEVICE_TYPING | None = None):
        self.out = None
        if device is None:
            device = "cpu"
-        self.device = torch.device(device)
+        self.device = _make_ordinal_device(torch.device(device))

    def __call__(self, list_of_tds):
        if self.out is None:
9 changes: 7 additions & 2 deletions torchrl/data/replay_buffers/storages.py
@@ -38,6 +38,7 @@
    INT_CLASSES,
    tree_iter,
)
+from torchrl.data.utils import _make_ordinal_device


class Storage:
@@ -405,7 +406,7 @@ def __init__(
        else:
            self._len = 0
        self.device = (
-            torch.device(device)
+            _make_ordinal_device(torch.device(device))
            if device != "auto"
            else storage.device
            if storage is not None
@@ -983,7 +984,11 @@ def __init__(
            self.scratch_dir = str(scratch_dir)
            if self.scratch_dir[-1] != "/":
                self.scratch_dir += "/"
-        self.device = torch.device(device) if device != "auto" else torch.device("cpu")
+        self.device = (
+            _make_ordinal_device(torch.device(device))
+            if device != "auto"
+            else torch.device("cpu")
+        )
        if self.device.type != "cpu":
            raise ValueError(
                "Memory map device other than CPU isn't supported. To cast your data to the desired device, "
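Beyond normalization, the second hunk keeps the long-standing constraint that memory-mapped storages live on CPU, and the normalized device makes that check unambiguous. A quick illustration, assuming the constructor behaves as in the code above:

from torchrl.data.replay_buffers.storages import LazyMemmapStorage

storage = LazyMemmapStorage(max_size=100)  # stored on CPU, as memmap requires
# LazyMemmapStorage(max_size=100, device="cuda:0") would raise a ValueError,
# since memory-map devices other than CPU aren't supported.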
18 changes: 14 additions & 4 deletions torchrl/data/tensor_specs.py
@@ -39,8 +39,8 @@
    unravel_key,
)
from tensordict.utils import _getitem_batch_size, NestedKey

from torchrl._utils import get_binary_env_var
+from torchrl.data.utils import _make_ordinal_device

DEVICE_TYPING = Union[torch.device, str, int]

@@ -91,7 +91,7 @@ def _default_dtype_and_device(
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is not None:
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
    elif not allow_none_device:
        device = torch.zeros(()).device
    return dtype, device
@@ -536,6 +536,14 @@ def decorator(func):

        return decorator

+    @property
+    def device(self) -> torch.device:
+        return self._device
+
+    @device.setter
+    def device(self, device: torch.device | None) -> None:
+        self._device = _make_ordinal_device(device)
+
    def clear_device_(self):
        """A no-op for all leaf specs (which must have a device)."""
        return self
@@ -3802,7 +3810,9 @@ def __init__(self, *args, shape=None, device=None, **kwargs):
        for key, value in kwargs.items():
            self.set(key, value)

-        _device = torch.device(device) if device is not None else device
+        _device = (
+            _make_ordinal_device(torch.device(device)) if device is not None else device
+        )
        if len(kwargs):
            for key, item in self.items():
                if item is None:
@@ -3845,7 +3855,7 @@ def device(self, device: DEVICE_TYPING):
            raise RuntimeError(
                "To erase the device of a composite spec, call spec.clear_device_()."
            )
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
        self.to(device)

    def clear_device_(self):
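With the new property/setter pair above, every leaf spec normalizes its device at assignment time rather than leaving the comparison to callers. A small usage sketch of the intended behavior (assumes a CUDA build):

import torch

from torchrl.data import UnboundedContinuousTensorSpec

spec = UnboundedContinuousTensorSpec((3,), device="cuda")
# The setter routes through _make_ordinal_device, so the stored device
# carries an explicit index and compares equal to cuda:0.
assert spec.device == torch.device("cuda:0")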
8 changes: 8 additions & 0 deletions torchrl/data/utils.py
@@ -324,3 +324,11 @@ def _find_action_space(action_space):
f"action_space was not specified/not compatible and could not be retrieved from the value network. Got action_space={action_space}."
)
return action_space


def _make_ordinal_device(device: torch.device):
if device is None:
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
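For context on why this helper exists: PyTorch treats an index-less `torch.device("cuda")` and an explicit `torch.device("cuda:0")` as unequal objects, even when both refer to the same physical GPU. A minimal illustration of the pitfall, assuming a CUDA build whose current device is 0:

import torch

a = torch.device("cuda")  # index is None
b = torch.device("cuda:0")  # explicit ordinal
assert a != b  # unequal, despite naming the same GPU
assert a.index is None and b.index == 0

# _make_ordinal_device pins the missing index to the current device,
# after which the two compare equal.
from torchrl.data.utils import _make_ordinal_device
assert _make_ordinal_device(a) == b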
17 changes: 12 additions & 5 deletions torchrl/envs/batched_envs.py
@@ -36,7 +36,12 @@
    VERBOSE,
)
from torchrl.data.tensor_specs import CompositeSpec
-from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
+from torchrl.data.utils import (
+    _make_ordinal_device,
+    CloudpickleWrapper,
+    contains_lazy_spec,
+    DEVICE_TYPING,
+)
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
from torchrl.envs.env_creator import get_env_metadata

@@ -346,7 +351,9 @@ def __init__(
"memmap and shared memory are mutually exclusive features."
)
self._batch_size = None
self._device = torch.device(device) if device is not None else device
self._device = (
_make_ordinal_device(torch.device(device)) if device is not None else device
)
self._dummy_env_str = None
self._seeds = None
self.__dict__["_input_spec"] = None
@@ -835,7 +842,7 @@ def start(self) -> None:

    def to(self, device: DEVICE_TYPING):
        self._non_blocking = None
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
        if device == self.device:
            return self
        self._device = device
@@ -1114,7 +1121,7 @@ def __getattr__(self, attr: str) -> Any:
        )

    def to(self, device: DEVICE_TYPING):
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
        if device == self.device:
            return self
        super().to(device)
@@ -1789,7 +1796,7 @@ def __getattr__(self, attr: str) -> Any:
        )

    def to(self, device: DEVICE_TYPING):
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
        if device == self.device:
            return self
        super().to(device)
8 changes: 4 additions & 4 deletions torchrl/envs/common.py
@@ -31,7 +31,7 @@
    TensorSpec,
    UnboundedContinuousTensorSpec,
)
-from torchrl.data.utils import DEVICE_TYPING
+from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.envs.utils import (
    _make_compatible_policy,
    _repr_by_depth,
@@ -154,7 +154,7 @@ def clone(self):

    def to(self, device: DEVICE_TYPING) -> EnvMetaData:
        if device is not None:
-            device = torch.device(device)
+            device = _make_ordinal_device(torch.device(device))
        device_map = {key: device for key in self.device_map}
        tensordict = self.tensordict.contiguous().to(device)
        specs = self.specs.to(device)
@@ -348,7 +348,7 @@ def __init__(
    ):
        self.__dict__.setdefault("_batch_size", None)
        if device is not None:
-            self.__dict__["_device"] = torch.device(device)
+            self.__dict__["_device"] = _make_ordinal_device(torch.device(device))
        output_spec = self.__dict__.get("_output_spec")
        if output_spec is not None:
            self.__dict__["_output_spec"] = (
@@ -2947,7 +2947,7 @@ def __del__(self):
        pass

    def to(self, device: DEVICE_TYPING) -> EnvBase:
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
        if device == self.device:
            return self
        self.__dict__["_input_spec"] = self.input_spec.to(device).lock_()
4 changes: 2 additions & 2 deletions torchrl/envs/libs/habitat.py
@@ -7,7 +7,7 @@

import torch

-from torchrl.data.utils import DEVICE_TYPING
+from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import _classproperty
@@ -118,7 +118,7 @@ def _build_gym_env(self, env, pixels_only):
        return super()._build_gym_env(env, pixels_only)

    def to(self, device: DEVICE_TYPING) -> EnvBase:
-        device = torch.device(device)
+        device = _make_ordinal_device(torch.device(device))
        if device.type != "cuda":
            raise ValueError("The device must be of type cuda for Habitat.")
        device_num = device.index
25 changes: 17 additions & 8 deletions torchrl/envs/transforms/transforms.py
@@ -58,6 +58,7 @@
    TensorSpec,
    UnboundedContinuousTensorSpec,
)
+from torchrl.data.utils import _make_ordinal_device
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
from torchrl.envs.transforms.utils import (
@@ -2084,13 +2085,21 @@ def __new__(cls, *args, **kwargs):

    def __init__(
        self,
-        unsqueeze_dim: int,
+        dim: int = None,
        allow_positive_dim: bool = False,
        in_keys: Sequence[NestedKey] | None = None,
        out_keys: Sequence[NestedKey] | None = None,
        in_keys_inv: Sequence[NestedKey] | None = None,
        out_keys_inv: Sequence[NestedKey] | None = None,
+        **kwargs,
    ):
+        if "unsqueeze_dim" in kwargs:
+            warnings.warn(
+                "The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead."
+            )
+            dim = kwargs["unsqueeze_dim"]
+        elif dim is None:
+            raise TypeError("dim must be provided.")
        if in_keys is None:
            in_keys = []  # default
        if out_keys is None:
@@ -2106,19 +2115,19 @@ def __init__(
            out_keys_inv=out_keys_inv,
        )
        self.allow_positive_dim = allow_positive_dim
-        if unsqueeze_dim >= 0 and not allow_positive_dim:
+        if dim >= 0 and not allow_positive_dim:
            raise RuntimeError(
-                "unsqueeze_dim should be smaller than 0 to accommodate for "
+                "dim should be smaller than 0 to accommodate for "
                "envs of different batch_sizes. Turn allow_positive_dim to accommodate "
                "for positive unsqueeze_dim."
            )
-        self._unsqueeze_dim = unsqueeze_dim
+        self._dim = dim

    @property
    def unsqueeze_dim(self):
-        if self._unsqueeze_dim >= 0 and self.parent is not None:
-            return len(self.parent.batch_size) + self._unsqueeze_dim
-        return self._unsqueeze_dim
+        if self._dim >= 0 and self.parent is not None:
+            return len(self.parent.batch_size) + self._dim
+        return self._dim

    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
        observation = observation.unsqueeze(self.unsqueeze_dim)
@@ -3808,7 +3817,7 @@ def __init__(
        in_keys_inv=None,
        out_keys_inv=None,
    ):
-        device = self.device = torch.device(device)
+        device = self.device = _make_ordinal_device(torch.device(device))
        self.orig_device = (
            torch.device(orig_device) if orig_device is not None else orig_device
        )
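Note that the `unsqueeze_dim` to `dim` rename above is backward compatible: the old keyword is absorbed by `**kwargs`, mapped onto `dim`, and flagged with a deprecation warning. A short usage sketch (the key name is illustrative):

from torchrl.envs.transforms import UnsqueezeTransform

t = UnsqueezeTransform(dim=-1, in_keys=["observation"])  # preferred spelling
legacy = UnsqueezeTransform(unsqueeze_dim=-1, in_keys=["observation"])  # warns; removed in v0.6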
4 changes: 2 additions & 2 deletions torchrl/trainers/helpers/replay_buffer.py
@@ -13,14 +13,14 @@
)
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
-from torchrl.data.utils import DEVICE_TYPING
+from torchrl.data.utils import _make_ordinal_device, DEVICE_TYPING


def make_replay_buffer(
    device: DEVICE_TYPING, cfg: "DictConfig"  # noqa: F821
) -> ReplayBuffer:  # noqa: F821
    """Builds a replay buffer using the config built from ReplayArgsConfig."""
-    device = torch.device(device)
+    device = _make_ordinal_device(torch.device(device))
    if not cfg.prb:
        sampler = RandomSampler()
    else:
