
Commit b5736ab

Author: Vincent Moens
[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: 17f1ce8
Pull Request resolved: #1041
1 parent 7e45bcc commit b5736ab

File tree

1 file changed (+72 additions, -1 deletion)


tensordict/base.py

Lines changed: 72 additions & 1 deletion
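
The diff below adds a compile-friendly branch to `_to_consolidated`: outside of Dynamo, leaves are rebuilt by pointing fresh tensors at the cast untyped storage (`_to_reconstruct`); under Dynamo, they are rebuilt by slicing a dtype view of the consolidated storage (`_to_reconstruct_compiled`). For orientation only (not part of the commit), here is a hedged sketch of how this path is typically reached, assuming tensordict's public `consolidate()`/`to()` API routes a consolidated TensorDict through `_to_consolidated`:

    import torch
    from tensordict import TensorDict

    # Illustrative only: the exact dispatch into _to_consolidated is assumed here.
    td = TensorDict({"a": torch.randn(3, 4), "b": torch.zeros(3)}, batch_size=[3])
    td = td.consolidate()  # pack every leaf into one contiguous storage

    @torch.compile
    def to_device(td):
        # Under compile, the new is_dynamo_compiling() branch would select
        # _to_reconstruct_compiled instead of _to_reconstruct.
        return td.to("cuda", non_blocking=True)

    if torch.cuda.is_available():
        td_cuda = to_device(td)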
@@ -10456,17 +10456,88 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
         if pin_memory:
             storage = storage.pin_memory()
         storage_cast = storage.to(device, non_blocking=True)
+        if is_dynamo_compiling():
+            return self._to_reconstruct_compiled(
+                storage, storage_cast, device, num_threads, non_blocking
+            )
+        return self._to_reconstruct(
+            storage, storage_cast, device, num_threads, non_blocking
+        )
+
+    def _to_reconstruct(self, storage, storage_cast, device, num_threads, non_blocking):
         untyped_storage = storage_cast.untyped_storage()

         def set_(x):
+            if x.is_nested:
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                return torch.nested.nested_tensor_from_jagged(
+                    set_(values),
+                    offsets=set_(offsets),
+                    lengths=set_(lengths) if lengths is not None else None,
+                )
             storage_offset = x.storage_offset()
             stride = x.stride()
-            return torch.empty_like(x, device=device).set_(
+            return x.new_empty((0,), device=device).set_(
                 untyped_storage,
                 size=x.shape,
                 stride=stride,
                 storage_offset=storage_offset,
             )
+            # return torch.empty_like(x, device=device).set_(
+            #     untyped_storage,
+            #     size=x.shape,
+            #     stride=stride,
+            #     storage_offset=storage_offset,
+            # )
+
+        result = self._fast_apply(
+            set_, device=torch.device(device), num_threads=num_threads
+        )
+        result._consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_reconstruct_compiled(self, storage, storage_cast, device, num_threads, non_blocking):
+        def set_(x):
+            if x.is_nested:
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                return torch._nested_view_from_jagged(
+                    set_(values),
+                    set_(offsets),
+                    x,
+                    lengths=set_(lengths) if lengths is not None else None,
+                )
+            storage_offset = x.storage_offset()
+            stride = x.stride()
+            index_slice = slice(storage_offset, storage_offset + x.numel(), stride[0])
+            return storage_cast.view(x.dtype)[index_slice].view(x.type)

         result = self._fast_apply(
             set_, device=torch.device(device), num_threads=num_threads

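In a nutshell, the compiled branch avoids `untyped_storage()` and `Tensor.set_`, which Dynamo cannot trace well, and instead stays in ordinary tensor ops: view the flat, already-cast buffer with the leaf's dtype, slice out its elements, and reshape. A minimal self-contained sketch of that idea follows; it is not tensordict code, the helper name is made up, and it assumes contiguous same-dtype leaves packed back to back, which simplifies what the commit handles:

    import torch

    def rebuild_from_flat(flat_u8, offset_elems, shape, dtype):
        # Reinterpret the flat uint8 buffer with the leaf dtype, then take the
        # contiguous run of elements that belongs to this leaf and reshape it.
        flat = flat_u8.view(dtype)
        numel = 1
        for s in shape:
            numel *= s
        return flat[offset_elems : offset_elems + numel].view(shape)

    # Toy round trip: pack two float32 tensors into one byte buffer, rebuild views.
    a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    b = torch.ones(4, dtype=torch.float32)
    flat = torch.cat([a.reshape(-1), b]).view(torch.uint8)  # stand-in for storage_cast
    a2 = rebuild_from_flat(flat, 0, a.shape, a.dtype)
    b2 = rebuild_from_flat(flat, a.numel(), b.shape, b.dtype)
    assert torch.equal(a2, a) and torch.equal(b2, b)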
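The nested-tensor branch of `set_` takes a jagged NJT apart into its `values`/`offsets`/`lengths` buffers, moves each one, and reassembles the result (via `torch.nested.nested_tensor_from_jagged` in eager mode, and the private `torch._nested_view_from_jagged` under compile). Below is a small round-trip sketch of the public jagged API it leans on, assuming a recent PyTorch that exposes it:

    import torch

    # Three variable-length rows (2, 4 and 1 elements) packed along dim 0.
    values = torch.randn(7, 5)
    offsets = torch.tensor([0, 2, 6, 7])  # row i spans values[offsets[i]:offsets[i + 1]]
    njt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)

    # set_() in the diff maps values/offsets to the target device one by one and
    # rebuilds the jagged tensor the same way; here we only check the round trip.
    assert njt.is_nested and njt.layout == torch.jagged
    assert torch.equal(njt.values(), values) and torch.equal(njt.offsets(), offsets)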