diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py
index 6eae0a249..2088d764c 100644
--- a/tensordict/_reductions.py
+++ b/tensordict/_reductions.py
@@ -32,16 +32,23 @@ def from_metadata(metadata=metadata_dict, prefix=None):
         key: NonTensorData(data, batch_size=batch_size)
         for (key, (data, batch_size)) in non_tensor.items()
     }
-    for key, _ in leaves.items():
+    for key in leaves.keys():
         total_key = (key,) if prefix is None else prefix + (key,)
         if total_key[-1].startswith("<NJT>"):
             nested_values = flat_key_values[total_key]
+            nested_lengths = None
             continue
-        if total_key[-1].startswith("<NJT_OFFSETS>"):
+        elif total_key[-1].startswith("<NJT_LENGTHS>"):
+            nested_lengths = flat_key_values[total_key]
+            continue
+        elif total_key[-1].startswith("<NJT_OFFSETS>"):
             offsets = flat_key_values[total_key]
             key = key.replace("<NJT_OFFSETS>", "")
-            value = torch.nested.nested_tensor_from_jagged(nested_values, offsets)
+            value = torch.nested.nested_tensor_from_jagged(
+                nested_values, offsets=offsets, lengths=nested_lengths
+            )
             del nested_values
+            del nested_lengths
         else:
             value = flat_key_values[total_key]
         d[key] = value
@@ -93,10 +100,16 @@ def from_metadata(metadata=metadata, prefix=None):
             value = value.view(local_shape)
             if key.startswith("<NJT>"):
                 nested_values = value
+                nested_lengths = None
+                continue
+            elif key.startswith("<NJT_LENGTHS>"):
+                nested_lengths = value
                 continue
             elif key.startswith("<NJT_OFFSETS>"):
                 offsets = value
-                value = torch.nested.nested_tensor_from_jagged(nested_values, offsets)
+                value = torch.nested.nested_tensor_from_jagged(
+                    nested_values, offsets=offsets, lengths=nested_lengths
+                )
                 key = key.replace("<NJT_OFFSETS>", "")
             d[key] = value
         for k, v in metadata.items():
diff --git a/tensordict/base.py b/tensordict/base.py
index 529146597..bc83464ed 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -3590,7 +3590,10 @@ def assign(
                 values = value.values()
                 shape = [v if isinstance(v, int) else -1 for v in values.shape]
                 # Get the offsets
-                offsets = value.offsets()
+                offsets = value._offsets
+                # Get the lengths
+                lengths = value._lengths
+                # Now we're saving the two tensors
                 # We will rely on the fact that the writing order is preserved in python dict
                 # (since python 3.7). Later, we will read the NJT then the NJT offset in that order
@@ -3602,9 +3605,22 @@
                     metadata_dict,
                     values.dtype,
                     shape,
-                    # values.device,
                     flat_size,
                 )
+                # Lengths
+                if lengths is not None:
+                    flat_key_values[
+                        _prefix_last_key(total_key, "<NJT_LENGTHS>")
+                    ] = lengths
+                    add_single_value(
+                        lengths,
+                        _prefix_last_key(key, "<NJT_LENGTHS>"),
+                        metadata_dict,
+                        lengths.dtype,
+                        lengths.shape,
+                        flat_size,
+                    )
+                # Offsets
                 flat_key_values[_prefix_last_key(total_key, "<NJT_OFFSETS>")] = (
                     offsets
                 )
@@ -3614,9 +3630,9 @@
                     metadata_dict,
                     offsets.dtype,
                     offsets.shape,
-                    # offsets.device,
                     flat_size,
                 )
+
             else:
                 raise NotImplementedError(
                     "NST is not supported, please use layout=torch.jagged when building the nested tensor."
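
Why `lengths` has to travel with `offsets`: an NJT built with explicit lengths may use only part of each offset-delimited slice of its values buffer, so rebuilding it from `values` and `offsets` alone (the pre-patch behavior) silently yields a different tensor. Below is a minimal standalone sketch of the decompose/rebuild round trip that the two `from_metadata` branches above now perform. It is an illustration, not the library code, and assumes the `torch.nested.nested_tensor_from_jagged` semantics of recent PyTorch (2.4+):

```python
import torch

values = torch.arange(10.0)
offsets = torch.tensor([0, 2, 5, 8, 10])  # where each row starts in `values`
lengths = torch.tensor([2, 2, 3, 1])      # rows may stop short of the next offset

njt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets, lengths=lengths)

# Decompose into flat tensors, reading the same slots the patch uses.
v, o, l = njt.values(), njt._offsets, njt._lengths

# Rebuild. Passing lengths=None here (what the old code effectively did)
# would give every row the full offset-to-offset span, changing the
# content of rows 1 and 3.
rebuilt = torch.nested.nested_tensor_from_jagged(v, offsets=o, lengths=l)
assert rebuilt._lengths is not None
```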
@@ -3785,12 +3801,14 @@ def view_old_as_new(v, oldv):
         if num_threads > 0:

             def assign(
+                *,
                 k,
                 v,
                 start,
                 stop,
                 njts,
                 njts_offsets,
+                njts_lengths,
                 storage=storage,
                 non_blocking=non_blocking,
             ):
@@ -3810,12 +3828,15 @@
                 new_v = new_v.view(shape)
                 if k[-1].startswith("<NJT>"):
                     njts[k] = new_v
+                elif k[-1].startswith("<NJT_LENGTHS>"):
+                    njts_lengths[k] = new_v
                 elif k[-1].startswith("<NJT_OFFSETS>"):
                     njts_offsets[k] = new_v
                 flat_dict[k] = new_v

             njts = {}
             njts_offsets = {}
+            njts_lengths = {}
             if num_threads > 1:
                 executor = ThreadPoolExecutor(num_threads)
                 r = []
@@ -3823,12 +3844,13 @@
                     r.append(
                         executor.submit(
                             assign,
-                            k,
-                            v,
-                            offsets[i],
-                            offsets[i + 1],
-                            njts,
-                            njts_offsets,
+                            k=k,
+                            v=v,
+                            start=offsets[i],
+                            stop=offsets[i + 1],
+                            njts=njts,
+                            njts_offsets=njts_offsets,
+                            njts_lengths=njts_lengths,
                         )
                     )
                 if not return_early:
@@ -3841,22 +3863,29 @@
             else:
                 for i, (k, v) in enumerate(flat_dict.items()):
                     assign(
-                        k,
-                        v,
-                        offsets[i],
-                        offsets[i + 1],
-                        njts,
-                        njts_offsets,
+                        k=k,
+                        v=v,
+                        start=offsets[i],
+                        stop=offsets[i + 1],
+                        njts=njts,
+                        njts_offsets=njts_offsets,
+                        njts_lengths=njts_lengths,
                     )
             for njt_key, njt_val in njts.items():
                 njt_key_offset = njt_key[:-1] + (
                     njt_key[-1].replace("<NJT>", "<NJT_OFFSETS>"),
                 )
+                njt_key_lengths = njt_key[:-1] + (
+                    njt_key[-1].replace("<NJT>", "<NJT_LENGTHS>"),
+                )
                 val = torch.nested.nested_tensor_from_jagged(
-                    njt_val, flat_dict[njt_key_offset]
+                    njt_val,
+                    offsets=flat_dict[njt_key_offset],
+                    lengths=flat_dict.get(njt_key_lengths),
                 )
                 del flat_dict[njt_key]
                 del flat_dict[njt_key_offset]
+                flat_dict.pop(njt_key_lengths, None)
                 newkey = njt_key[:-1] + (njt_key[-1].replace("<NJT>", ""),)
                 flat_dict[newkey] = val
@@ -3896,13 +3925,20 @@ def _view_and_pad(tensor):
             elif k[-1].startswith("<NJT>"):
                 # NJT/NT always comes before offsets/shapes
                 _nested_values = view_old_as_new(v, oldv)
+                nt_lengths = None
+                del flat_dict[k]
+            elif k[-1].startswith("<NJT_LENGTHS>"):
+                nt_lengths = view_old_as_new(v, oldv)
                 del flat_dict[k]
             elif k[-1].startswith("<NJT_OFFSETS>"):
                 newk = k[:-1] + (k[-1].replace("<NJT_OFFSETS>", ""),)
                 nt_offsets = view_old_as_new(v, oldv)
                 del flat_dict[k]
+
                 flat_dict[newk] = torch.nested.nested_tensor_from_jagged(
-                    _nested_values, nt_offsets
+                    _nested_values,
+                    offsets=nt_offsets,
+                    lengths=nt_lengths,
                 )
                 # delete the nested value to make sure that if there was an
                 # ordering mismatch we wouldn't be looking at the value key of
diff --git a/tensordict/utils.py b/tensordict/utils.py
index c1c60ed08..fd3140401 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1541,6 +1541,9 @@ def assert_close(
             continue
         elif not isinstance(input1, torch.Tensor):
             continue
+        if input1.is_nested:
+            input1 = input1._base
+            input2 = input2._base
         mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
         mse = mse.div(input1.numel()).sqrt().item()
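
The `assert_close` change above is the test-side counterpart: elementwise subtraction is not generally available between two NJTs (in particular when they carry lengths), so the comparison falls back to the dense buffers that the jagged tensors are views over, reachable through `Tensor._base`. A hypothetical standalone equivalent of that fallback (the helper name is illustrative, not part of the library):

```python
import torch

def nested_rmse(input1: torch.Tensor, input2: torch.Tensor) -> float:
    # NJTs are views over a flat values buffer; compare those buffers
    # (via ._base) rather than the jagged views themselves.
    if input1.is_nested:
        input1 = input1._base
        input2 = input2._base
    mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
    return mse.div(input1.numel()).sqrt().item()
```

Note that this is stricter than a purely semantic comparison: for NJTs with lengths, values sitting in the length-excluded holes are compared too. For the round-trip tests here, where both sides should share byte-identical storage, that is acceptable.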
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 32cae51fa..a7afd21fc 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -9,6 +9,7 @@
 import functools
 import gc
+import importlib.util
 import json
 import os
 import pathlib
@@ -22,8 +23,8 @@
 import numpy as np
 import pytest
-import tensordict.base as tensordict_base
+
+import tensordict.base as tensordict_base
 import torch
 from _utils_internal import (
     decompose,
@@ -32,6 +33,7 @@
     prod,
     TestTensorDictsBase,
 )
+from packaging import version
 from tensordict import (
     get_defaults_to_none,
@@ -42,6 +44,7 @@
     TensorDict,
 )
 from tensordict._lazy import _CustomOpTensorDict
+from tensordict._reductions import _reduce_td
 from tensordict._td import _SubTensorDict, is_tensor_collection
 from tensordict._torch_func import _stack as stack_td
 from tensordict.base import _is_leaf_nontensor, _NESTED_TENSORS_AS_LISTS, TensorDictBase
@@ -89,6 +92,11 @@
     _has_h5py = True
 except ImportError:
     _has_h5py = False
+TORCH_VERSION = version.parse(torch.__version__).base_version
+
+_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None
+
+_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0")

 _IS_OSX = platform.system() == "Darwin"
 _IS_WINDOWS = sys.platform == "win32"
@@ -7852,6 +7860,7 @@ def test_consolidate(self, device, use_file, tmpdir):
             batch_size=[1, 3],
         )
         td = LazyStackedTensorDict(*td.unbind(1), stack_dim=1)
+
         if not use_file:
             td_c = td.consolidate()
             assert td_c.device == device
@@ -7864,6 +7873,7 @@
         assert type(td_c) == type(td)  # noqa
         assert (td.to(td_c.device) == td_c).all()
         assert td_c["d"] == [["a string!"] * 3]
+
         storage = td_c._consolidated["storage"]
         storage *= 0
         assert (td.to(td_c.device) != td_c).any()
@@ -7887,6 +7897,67 @@
         torch.utils._pytree.tree_map(check_id, td_c._consolidated, tdload._consolidated)
         assert tdload.is_consolidated()

+    @pytest.mark.skipif(not _v2_5, reason="v2.5 required for this test")
+    @pytest.mark.parametrize("device", [None, *get_available_devices()])
+    @pytest.mark.parametrize("use_file", [False, True])
+    def test_consolidate_njt(self, device, use_file, tmpdir):
+        td = TensorDict(
+            {
+                "a": torch.arange(3).expand(4, 3).clone(),
+                "b": {"c": torch.arange(3, dtype=torch.double).expand(4, 3).clone()},
+                "d": "a string!",
+                "njt": torch.nested.nested_tensor_from_jagged(
+                    torch.arange(10, device=device),
+                    offsets=torch.tensor([0, 2, 5, 8, 10], device=device),
+                ),
+                "njt_lengths": torch.nested.nested_tensor_from_jagged(
+                    torch.arange(10, device=device),
+                    offsets=torch.tensor([0, 2, 5, 8, 10], device=device),
+                    lengths=torch.tensor([2, 3, 3, 2], device=device),
+                ),
+            },
+            device=device,
+            batch_size=[4],
+        )
+
+        if not use_file:
+            td_c = td.consolidate()
+            assert td_c.device == device
+        else:
+            filename = Path(tmpdir) / "file.mmap"
+            td_c = td.consolidate(filename=filename)
+            assert td_c.device == torch.device("cpu")
+            assert assert_allclose_td(TensorDict.from_consolidated(filename), td_c)
+        assert hasattr(td_c, "_consolidated")
+        assert type(td_c) == type(td)  # noqa
+        assert td_c["d"] == "a string!"
+        with (
+            pytest.raises(KeyError)
+            if td.device != td_c.device and device is not None
+            else contextlib.nullcontext()
+        ):
+            # njt.to(device) is currently broken when it has lengths
+            assert_allclose_td(td.to(td_c.device), td_c)
+
+        tdload_make, tdload_data = _reduce_td(td)
+        tdload = tdload_make(*tdload_data)
+        assert (td == tdload).all()
+
+        td_c = td.consolidate()
+        tdload_make, tdload_data = _reduce_td(td_c)
+        tdload = tdload_make(*tdload_data)
+        assert assert_allclose_td(td, tdload)
+
+        def check_id(a, b):
+            if isinstance(a, (torch.Size, str)):
+                assert a == b
+            if isinstance(a, torch.Tensor):
+                assert (a == b).all()
+
+        torch.utils._pytree.tree_map(check_id, td_c._consolidated, tdload._consolidated)
+        assert tdload.is_consolidated()
+        assert tdload["njt_lengths"]._lengths is not None
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected")
     def test_consolidate_to_device(self):
         td = TensorDict(
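
Taken together, `test_consolidate_njt` exercises both serialization paths, `_reduce_td` (which backs pickling) and `consolidate()`/`from_consolidated()`, and pins the post-condition that motivated the patch: `_lengths` survives the round trip. The user-visible behavior being locked in looks roughly like this (a sketch, assuming PyTorch 2.5+ as per the `_v2_5` guard in the test file):

```python
import pickle

import torch
from tensordict import TensorDict

td = TensorDict(
    {
        "njt_lengths": torch.nested.nested_tensor_from_jagged(
            torch.arange(10),
            offsets=torch.tensor([0, 2, 5, 8, 10]),
            lengths=torch.tensor([2, 3, 3, 2]),
        ),
    },
    batch_size=[4],
)

# Pickling goes through _reduce_td; consolidate() first packs every leaf
# (values, offsets and now lengths) into a single flat storage.
td2 = pickle.loads(pickle.dumps(td.consolidate()))
assert td2["njt_lengths"]._lengths is not None  # lengths are preserved
```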