[Feature] memory mapped jagged tensors #1291

Open
wants to merge 1 commit into base: gh/vmoens/51/base
19 changes: 16 additions & 3 deletions tensordict/_td.py
@@ -59,7 +59,7 @@
_LOCK_ERROR,
_maybe_correct_neg_dim,
_mismatch_keys,
_NON_STR_KEY_ERR,
_NON_STR_KEY_ERR,
_nested_tensor_shape,
_NON_STR_KEY_TUPLE_ERR,
_parse_to,
_pass_through,
@@ -4701,7 +4701,19 @@ def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None):
def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok):
filename = None if prefix is None else str(prefix / f"{key}.memmap")
if value.is_nested:
shape = value._nested_tensor_size()
if value.layout is torch.strided:
shape = value._nested_tensor_size()
else:
offsets = value.offsets()
if offsets is None:
lengths = value.lengths()
else:
lengths = offsets.diff()
shapes = [lengths]
for s in value.shape[2:]:
shapes.append(torch.full_like(lengths, s))
shape = torch.stack(shapes, -1)
value = value.values()
# Make the shape a memmap tensor too
if prefix is not None:
shape_filename = Path(filename)
@@ -4713,6 +4725,7 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok)
existsok=existsok,
copy_data=True,
)

else:
shape = None
memmap_tensor = MemoryMappedTensor.from_tensor(
@@ -4795,7 +4808,7 @@ def _update_metadata(*, metadata, key, value, is_collection):
"shape": (
list(value.shape)
if not value.is_nested
else list(value._nested_tensor_size().shape)
else list(_nested_tensor_shape(value).shape)
),
"dtype": str(value.dtype),
"is_nested": value.is_nested,
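
For orientation, a minimal sketch of the shape table that the jagged branch of `_populate_memmap` builds before the values are flattened and memory-mapped. This is illustrative code, not part of the diff, and it assumes a `torch.jagged` nested tensor whose only ragged dimension is dim 1.

```python
import torch

# Illustration only: derive one (length, *trailing_dims) row per component.
nt = torch.nested.nested_tensor(
    [torch.zeros(2, 5), torch.zeros(3, 5)], layout=torch.jagged
)
offsets = nt.offsets()            # tensor([0, 2, 5])
lengths = offsets.diff()          # ragged lengths per row: tensor([2, 3])
shapes = [lengths]
for s in nt.shape[2:]:            # non-ragged trailing dims (here just 5)
    shapes.append(torch.full_like(lengths, s))
shape = torch.stack(shapes, -1)   # tensor([[2, 5], [3, 5]])
values = nt.values()              # flat (5, 5) buffer that gets memory-mapped
```
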
46 changes: 33 additions & 13 deletions tensordict/memmap.py
@@ -517,11 +517,11 @@ def zeros(cls, *args, **kwargs):

@classmethod
@overload
def empty(cls, *size, dtype=None, device=None, filename=None): ...
def empty(cls, *size, dtype=None, device=None, filename=None, layout=None): ...

@classmethod
@overload
def empty(cls, shape, *, dtype=None, device=None, filename=None): ...
def empty(cls, shape, *, dtype=None, device=None, filename=None, layout=None): ...

@classmethod
def empty(cls, *args, **kwargs):
@@ -539,8 +539,10 @@ def empty(cls, *args, **kwargs):
is provided, a handler is used.
existsok (bool, optional): whether it is ok to overwrite an existing file.
Defaults to ``False``.
layout (torch.layout): the layout of the tensor if nested. Only ``None`` (default),
``torch.jagged`` and ``torch.strided`` are accepted.
"""
shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
shape, device, dtype, _, filename, layout = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
@@ -573,11 +575,19 @@
else:
raise RuntimeError(NESTED_TENSOR_ERR)
result = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
result = torch._nested_view_from_buffer(
result,
shape,
*offsets_strides,
)
if layout in (None, torch.strided):
result = torch._nested_view_from_buffer(
result,
shape,
*offsets_strides,
layout=layout,
)
else:
result = result.view((-1, *shape[0].tolist()))
result = torch.nested.nested_tensor_from_jagged(
result,
lengths=result[:, 0],
)
result = cls(result)
result._handler = handler
return result
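
The reconstruction step in isolation: a sketch of how a flat buffer plus per-row lengths is turned back into a jagged view, mirroring the `nested_tensor_from_jagged` call above. The values and lengths below are made up for illustration.

```python
import torch

# Sketch: flat memory-mapped values plus per-row lengths give back a jagged view.
values = torch.arange(20, dtype=torch.float32).view(-1, 4)   # (5, 4) flat buffer
lengths = torch.tensor([2, 3])                                # rows per component
nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
assert nt.is_nested and nt.layout is torch.jagged
# nt[0] views values[0:2], nt[1] views values[2:5]; the buffer is shared, not copied.
```
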
@@ -597,11 +607,20 @@ def empty(cls, *args, **kwargs):
offsets_strides = func_offset_stride(shape)
else:
raise RuntimeError(NESTED_TENSOR_ERR)
result = torch._nested_view_from_buffer(
result,
shape,
*offsets_strides,
)
if layout in (None, torch.strided):
result = torch._nested_view_from_buffer(
result,
shape,
*offsets_strides,
)
else:
# TODO: we should not assume that the 2nd dim is the ragged one
result = result.view((-1, *shape[0, 1:].tolist()))
result = torch.nested.nested_tensor_from_jagged(
result,
lengths=result[:, 0],
)

result = cls(result)
result.filename = filename
return result
@@ -1030,6 +1049,7 @@ def _proc_args_const(*args, **kwargs):
kwargs.pop("dtype", None),
kwargs.pop("fill_value", None),
kwargs.pop("filename", None),
kwargs.pop("layout", None),
)


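
Putting the memmap.py changes together, the intended usage looks roughly like this. This is a sketch assuming the patch is applied; the `layout` keyword and the per-row shape convention are introduced by this PR and may still change.

```python
import torch
from tensordict import MemoryMappedTensor

# Per-row shapes; with torch.jagged only the leading (ragged) dim may differ.
shape = torch.tensor([[2, 3], [3, 3], [4, 3]])
mmap = MemoryMappedTensor.empty(
    shape,
    dtype=torch.float32,
    filename="/tmp/jagged.memmap",   # hypothetical path
    layout=torch.jagged,
)
mmap.fill_(0.0)
assert mmap.layout is torch.jagged and mmap.filename is not None
```
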
15 changes: 15 additions & 0 deletions tensordict/utils.py
@@ -3028,3 +3028,18 @@ def _check_is_unflatten(new_shape, old_shape, return_flatten_dim=False):
# j = len(new_shape) - j - 1
return out, (i, j)
return out

def _nested_tensor_shape(value):
if value.layout is torch.strided:
shape = value._nested_tensor_size()
else:
offsets = value.offsets()
if offsets is None:
lengths = value.lengths()
else:
lengths = offsets.diff()
shapes = [lengths]
for s in value.shape[2:]:
shapes.append(torch.full_like(lengths, s))
shape = torch.stack(shapes, -1)
return shape
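
On the strided branch, `_nested_tensor_shape` simply returns the per-component size table. A brief illustration; `_nested_tensor_size` is private PyTorch API and is only used for strided nested tensors here.

```python
import torch

# Strided (default) layout: each component keeps its own full size row.
nt = torch.nested.nested_tensor([torch.zeros(2, 3), torch.zeros(2, 4)])
print(nt._nested_tensor_size())
# tensor([[2, 3],
#         [2, 4]])
```
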
43 changes: 25 additions & 18 deletions test/test_memmap.py
@@ -588,15 +588,20 @@ def test_ne(self):
assert (memmap != ~memmap).all()


@pytest.mark.parametrize("layout", [torch.jagged, torch.strided])
class TestNestedTensor:
shape = torch.tensor([[2, 3], [2, 4], [3, 2]])
def shape(self, layout):
if layout is torch.strided:
return torch.tensor([[2, 3], [2, 4], [3, 2]])
return torch.tensor([[2, 3], [3, 3], [4, 3]])

@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
def test_with_filename(self, tmpdir):
def test_with_filename(self, tmpdir, layout):
filename = tmpdir + "/test_file2.memmap"
tensor = MemoryMappedTensor.empty(
self.shape, filename=filename, dtype=torch.int
self.shape(layout), filename=filename, dtype=torch.int, layout=layout,
)
assert tensor.layout is layout
assert isinstance(tensor, MemoryMappedTensor)
assert tensor.dtype == torch.int
tensor.fill_(2)
@@ -605,22 +610,24 @@ def test_with_filename(self, tmpdir):

filename = tmpdir + "/test_file0.memmap"
tensor = MemoryMappedTensor.zeros(
self.shape, filename=filename, dtype=torch.bool
self.shape(layout), filename=filename, dtype=torch.bool, layout=layout,
)
assert tensor.layout is layout
assert isinstance(tensor, MemoryMappedTensor)
assert tensor.dtype == torch.bool
assert tensor.filename is not None

filename = tmpdir + "/test_file1.memmap"
tensor = MemoryMappedTensor.ones(self.shape, filename=filename, dtype=torch.int)
tensor = MemoryMappedTensor.ones(self.shape(layout), filename=filename, dtype=torch.int, layout=layout)
assert tensor.layout is layout
assert type(tensor) is MemoryMappedTensor
assert tensor.dtype == torch.int
assert (tensor[0] == 1).all()
assert tensor.filename is not None

filename = tmpdir + "/test_file3.memmap"
tensor = torch.nested.nested_tensor(
[torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape)]
[torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape(layout))]
)
memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=filename)
assert type(memmap_tensor) is MemoryMappedTensor
@@ -629,35 +636,35 @@
assert (t1 == t2).all()

memmap_tensor2 = MemoryMappedTensor.from_filename(
filename, dtype=memmap_tensor.dtype, shape=self.shape
filename, dtype=memmap_tensor.dtype, shape=self.shape(layout)
)
assert type(memmap_tensor2) is MemoryMappedTensor
for t1, t2 in zip(memmap_tensor2, memmap_tensor):
assert t1.dtype == t2.dtype
assert (t1 == t2).all()

@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
def test_with_handler(self):
tensor = MemoryMappedTensor.empty(self.shape, dtype=torch.int)
def test_with_handler(self, layout):
tensor = MemoryMappedTensor.empty(self.shape(layout), dtype=torch.int, layout=layout)
assert isinstance(tensor, MemoryMappedTensor)
assert tensor.dtype == torch.int
tensor.fill_(2)
assert (tensor[0] == 2).all()
assert tensor._handler is not None

tensor = MemoryMappedTensor.zeros(self.shape, dtype=torch.bool)
tensor = MemoryMappedTensor.zeros(self.shape(layout), dtype=torch.bool, layout=layout)
assert isinstance(tensor, MemoryMappedTensor)
assert tensor.dtype == torch.bool
assert tensor._handler is not None

tensor = MemoryMappedTensor.ones(self.shape, dtype=torch.int)
tensor = MemoryMappedTensor.ones(self.shape(layout), dtype=torch.int, layout=layout)
assert type(tensor) is MemoryMappedTensor
assert tensor.dtype == torch.int
assert (tensor[0] == 1).all()
assert tensor._handler is not None

tensor = torch.nested.nested_tensor(
[torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape)]
[torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape(layout))]
)
memmap_tensor = MemoryMappedTensor.from_tensor(tensor)
assert type(memmap_tensor) is MemoryMappedTensor
@@ -666,7 +673,7 @@ def test_with_handler(self):
assert (t1 == t2).all()

memmap_tensor2 = MemoryMappedTensor.from_handler(
memmap_tensor._handler, dtype=memmap_tensor.dtype, shape=self.shape
memmap_tensor._handler, dtype=memmap_tensor.dtype, shape=self.shape(layout), layout=layout
)
assert type(memmap_tensor2) is MemoryMappedTensor
for t1, t2 in zip(memmap_tensor2, memmap_tensor):

@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
@pytest.mark.parametrize("with_filename", [False, True])
def test_from_storage(self, with_filename, tmpdir):
def test_from_storage(self, with_filename, tmpdir, layout):
if with_filename:
filename = Path(tmpdir) / "file.memmap"
filename = str(filename)
else:
filename = None
a = MemoryMappedTensor.from_tensor(
torch.arange(10, dtype=torch.float64), filename=filename
torch.arange(10, dtype=torch.float64), filename=filename, layout=layout,
)
assert type(a) is MemoryMappedTensor
shape = torch.tensor([[2, 2], [2, 3]])
b = MemoryMappedTensor.from_storage(
a.untyped_storage(), filename=filename, shape=shape, dtype=a.dtype
a.untyped_storage(), filename=filename, shape=shape, dtype=a.dtype, layout=layout,
)
assert type(b) is MemoryMappedTensor
assert (b._nested_tensor_size() == shape).all()
assert (b[0] == torch.arange(4).view(2, 2)).all()
assert (b[1] == torch.arange(4, 10).view(2, 3)).all()

@pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
def test_save_td_with_nested(self, tmpdir):
def test_save_td_with_nested(self, tmpdir, layout):
td = TensorDict(
{
"a": torch.nested.nested_tensor(
[
torch.arange(12, dtype=torch.float64).view(3, 4),
torch.arange(15, dtype=torch.float64).view(3, 5),
]
], layout=layout,
)
},
batch_size=[2, 3],
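
End to end, the behaviour these tests target can be sketched as follows, assuming the PR is applied; the path and tensor sizes are illustrative only.

```python
import torch
from tensordict import TensorDict

# A TensorDict holding a jagged nested entry, memory-mapped and reloaded.
nt = torch.nested.nested_tensor(
    [torch.zeros(3, 4), torch.zeros(5, 4)], layout=torch.jagged
)
td = TensorDict({"a": nt}, batch_size=[2])
td.memmap_("/tmp/td_jagged")                 # writes a.memmap plus its shape table
loaded = TensorDict.load_memmap("/tmp/td_jagged")
assert loaded["a"].is_nested
```
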