@@ -10456,17 +10456,88 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
         if pin_memory:
             storage = storage.pin_memory()
         storage_cast = storage.to(device, non_blocking=True)
+        if is_dynamo_compiling():
+            return self._to_reconstruct_compiled(
+                storage, storage_cast, device, num_threads, non_blocking
+            )
+        return self._to_reconstruct(
+            storage, storage_cast, device, num_threads, non_blocking
+        )
+
+    def _to_reconstruct(self, storage, storage_cast, device, num_threads, non_blocking):
         untyped_storage = storage_cast.untyped_storage()

         def set_(x):
+            if x.is_nested:
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                return torch.nested.nested_tensor_from_jagged(
+                    set_(values),
+                    offsets=set_(offsets),
+                    lengths=set_(lengths) if lengths is not None else None,
+                )
             storage_offset = x.storage_offset()
             stride = x.stride()
-            return torch.empty_like(x, device=device).set_(
+            return x.new_empty((0,), device=device).set_(
                 untyped_storage,
                 size=x.shape,
                 stride=stride,
                 storage_offset=storage_offset,
             )
+            # return torch.empty_like(x, device=device).set_(
+            #     untyped_storage,
+            #     size=x.shape,
+            #     stride=stride,
+            #     storage_offset=storage_offset,
+            # )
+
+        result = self._fast_apply(
+            set_, device=torch.device(device), num_threads=num_threads
+        )
+        result._consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_reconstruct_compiled(self, storage, storage_cast, device, num_threads, non_blocking):
+        def set_(x):
+            if x.is_nested:
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                return torch._nested_view_from_jagged(
+                    set_(values),
+                    set_(offsets),
+                    x,
+                    lengths=set_(lengths) if lengths is not None else None,
+                )
+            storage_offset = x.storage_offset()
+            stride = x.stride()
+            index_slice = slice(storage_offset, storage_offset + x.numel(), stride[0])
+            return storage_cast.view(x.dtype)[index_slice].view(x.shape)

         result = self._fast_apply(
             set_, device=torch.device(device), num_threads=num_threads
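For context, a minimal usage sketch of the path this hunk touches. It assumes the public `TensorDict.consolidate()` and `.to()` entry points from tensordict; the tensor names, shapes, and the CUDA availability check are illustrative and not part of the diff:

```python
# Sketch only: exercising the consolidated .to(device) path that dispatches
# into _to_consolidated (and, per this diff, _to_reconstruct / _to_reconstruct_compiled).
import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(4, 3), "reward": torch.zeros(4, 1)},
    batch_size=[4],
)
td_flat = td.consolidate()  # packs every leaf into one contiguous storage

if torch.cuda.is_available():
    # A single storage copy is issued; each leaf is then rebuilt as a view
    # over the casted storage instead of being copied one by one.
    td_cuda = td_flat.to("cuda", non_blocking=True)
```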